diff options
30 files changed, 2308 insertions, 314 deletions
diff --git a/pkg/sentry/socket/epsocket/epsocket.go b/pkg/sentry/socket/epsocket/epsocket.go index 3e66f9cbb..5812085fa 100644 --- a/pkg/sentry/socket/epsocket/epsocket.go +++ b/pkg/sentry/socket/epsocket/epsocket.go @@ -942,6 +942,19 @@ func getSockOptSocket(t *kernel.Task, s socket.Socket, ep commonEndpoint, family return int32(v), nil + case linux.SO_BINDTODEVICE: + var v tcpip.BindToDeviceOption + if err := ep.GetSockOpt(&v); err != nil { + return nil, syserr.TranslateNetstackError(err) + } + if len(v) == 0 { + return []byte{}, nil + } + if outLen < linux.IFNAMSIZ { + return nil, syserr.ErrInvalidArgument + } + return append([]byte(v), 0), nil + case linux.SO_BROADCAST: if outLen < sizeOfInt32 { return nil, syserr.ErrInvalidArgument @@ -1305,6 +1318,13 @@ func setSockOptSocket(t *kernel.Task, s socket.Socket, ep commonEndpoint, name i v := usermem.ByteOrder.Uint32(optVal) return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.ReusePortOption(v))) + case linux.SO_BINDTODEVICE: + n := bytes.IndexByte(optVal, 0) + if n == -1 { + n = len(optVal) + } + return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.BindToDeviceOption(optVal[:n]))) + case linux.SO_BROADCAST: if len(optVal) < sizeOfInt32 { return syserr.ErrInvalidArgument diff --git a/pkg/sentry/syscalls/linux/sys_socket.go b/pkg/sentry/syscalls/linux/sys_socket.go index 3bac4d90d..b5a72ce63 100644 --- a/pkg/sentry/syscalls/linux/sys_socket.go +++ b/pkg/sentry/syscalls/linux/sys_socket.go @@ -531,7 +531,7 @@ func SetSockOpt(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy return 0, nil, syserror.ENOTSOCK } - if optLen <= 0 { + if optLen < 0 { return 0, nil, syserror.EINVAL } if optLen > maxOptLen { diff --git a/pkg/tcpip/ports/ports.go b/pkg/tcpip/ports/ports.go index 315780c0c..40e202717 100644 --- a/pkg/tcpip/ports/ports.go +++ b/pkg/tcpip/ports/ports.go @@ -47,43 +47,76 @@ type portNode struct { refs int } -// bindAddresses is a set of IP addresses. -type bindAddresses map[tcpip.Address]portNode - -// isAvailable checks whether an IP address is available to bind to. -func (b bindAddresses) isAvailable(addr tcpip.Address, reuse bool) bool { - if addr == anyIPAddress { - if len(b) == 0 { - return true - } +// deviceNode is never empty. When it has no elements, it is removed from the +// map that references it. +type deviceNode map[tcpip.NICID]portNode + +// isAvailable checks whether binding is possible by device. If not binding to a +// device, check against all portNodes. If binding to a specific device, check +// against the unspecified device and the provided device. +func (d deviceNode) isAvailable(reuse bool, bindToDevice tcpip.NICID) bool { + if bindToDevice == 0 { + // Trying to binding all devices. if !reuse { + // Can't bind because the (addr,port) is already bound. return false } - for _, n := range b { - if !n.reuse { + for _, p := range d { + if !p.reuse { + // Can't bind because the (addr,port) was previously bound without reuse. return false } } return true } - // If all addresses for this portDescriptor are already bound, no - // address is available. - if n, ok := b[anyIPAddress]; ok { - if !reuse { + if p, ok := d[0]; ok { + if !reuse || !p.reuse { return false } - if !n.reuse { + } + + if p, ok := d[bindToDevice]; ok { + if !reuse || !p.reuse { return false } } - if n, ok := b[addr]; ok { - if !reuse { + return true +} + +// bindAddresses is a set of IP addresses. +type bindAddresses map[tcpip.Address]deviceNode + +// isAvailable checks whether an IP address is available to bind to. If the +// address is the "any" address, check all other addresses. Otherwise, just +// check against the "any" address and the provided address. +func (b bindAddresses) isAvailable(addr tcpip.Address, reuse bool, bindToDevice tcpip.NICID) bool { + if addr == anyIPAddress { + // If binding to the "any" address then check that there are no conflicts + // with all addresses. + for _, d := range b { + if !d.isAvailable(reuse, bindToDevice) { + return false + } + } + return true + } + + // Check that there is no conflict with the "any" address. + if d, ok := b[anyIPAddress]; ok { + if !d.isAvailable(reuse, bindToDevice) { + return false + } + } + + // Check that this is no conflict with the provided address. + if d, ok := b[addr]; ok { + if !d.isAvailable(reuse, bindToDevice) { return false } - return n.reuse } + return true } @@ -116,17 +149,17 @@ func (s *PortManager) PickEphemeralPort(testPort func(p uint16) (bool, *tcpip.Er } // IsPortAvailable tests if the given port is available on all given protocols. -func (s *PortManager) IsPortAvailable(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, reuse bool) bool { +func (s *PortManager) IsPortAvailable(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, reuse bool, bindToDevice tcpip.NICID) bool { s.mu.Lock() defer s.mu.Unlock() - return s.isPortAvailableLocked(networks, transport, addr, port, reuse) + return s.isPortAvailableLocked(networks, transport, addr, port, reuse, bindToDevice) } -func (s *PortManager) isPortAvailableLocked(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, reuse bool) bool { +func (s *PortManager) isPortAvailableLocked(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, reuse bool, bindToDevice tcpip.NICID) bool { for _, network := range networks { desc := portDescriptor{network, transport, port} if addrs, ok := s.allocatedPorts[desc]; ok { - if !addrs.isAvailable(addr, reuse) { + if !addrs.isAvailable(addr, reuse, bindToDevice) { return false } } @@ -138,14 +171,14 @@ func (s *PortManager) isPortAvailableLocked(networks []tcpip.NetworkProtocolNumb // reserved by another endpoint. If port is zero, ReservePort will search for // an unreserved ephemeral port and reserve it, returning its value in the // "port" return value. -func (s *PortManager) ReservePort(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, reuse bool) (reservedPort uint16, err *tcpip.Error) { +func (s *PortManager) ReservePort(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, reuse bool, bindToDevice tcpip.NICID) (reservedPort uint16, err *tcpip.Error) { s.mu.Lock() defer s.mu.Unlock() // If a port is specified, just try to reserve it for all network // protocols. if port != 0 { - if !s.reserveSpecificPort(networks, transport, addr, port, reuse) { + if !s.reserveSpecificPort(networks, transport, addr, port, reuse, bindToDevice) { return 0, tcpip.ErrPortInUse } return port, nil @@ -153,13 +186,13 @@ func (s *PortManager) ReservePort(networks []tcpip.NetworkProtocolNumber, transp // A port wasn't specified, so try to find one. return s.PickEphemeralPort(func(p uint16) (bool, *tcpip.Error) { - return s.reserveSpecificPort(networks, transport, addr, p, reuse), nil + return s.reserveSpecificPort(networks, transport, addr, p, reuse, bindToDevice), nil }) } // reserveSpecificPort tries to reserve the given port on all given protocols. -func (s *PortManager) reserveSpecificPort(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, reuse bool) bool { - if !s.isPortAvailableLocked(networks, transport, addr, port, reuse) { +func (s *PortManager) reserveSpecificPort(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, reuse bool, bindToDevice tcpip.NICID) bool { + if !s.isPortAvailableLocked(networks, transport, addr, port, reuse, bindToDevice) { return false } @@ -171,11 +204,16 @@ func (s *PortManager) reserveSpecificPort(networks []tcpip.NetworkProtocolNumber m = make(bindAddresses) s.allocatedPorts[desc] = m } - if n, ok := m[addr]; ok { + d, ok := m[addr] + if !ok { + d = make(deviceNode) + m[addr] = d + } + if n, ok := d[bindToDevice]; ok { n.refs++ - m[addr] = n + d[bindToDevice] = n } else { - m[addr] = portNode{reuse: reuse, refs: 1} + d[bindToDevice] = portNode{reuse: reuse, refs: 1} } } @@ -184,22 +222,28 @@ func (s *PortManager) reserveSpecificPort(networks []tcpip.NetworkProtocolNumber // ReleasePort releases the reservation on a port/IP combination so that it can // be reserved by other endpoints. -func (s *PortManager) ReleasePort(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16) { +func (s *PortManager) ReleasePort(networks []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16, bindToDevice tcpip.NICID) { s.mu.Lock() defer s.mu.Unlock() for _, network := range networks { desc := portDescriptor{network, transport, port} if m, ok := s.allocatedPorts[desc]; ok { - n, ok := m[addr] + d, ok := m[addr] + if !ok { + continue + } + n, ok := d[bindToDevice] if !ok { continue } n.refs-- + d[bindToDevice] = n if n.refs == 0 { + delete(d, bindToDevice) + } + if len(d) == 0 { delete(m, addr) - } else { - m[addr] = n } if len(m) == 0 { delete(s.allocatedPorts, desc) diff --git a/pkg/tcpip/ports/ports_test.go b/pkg/tcpip/ports/ports_test.go index 689401661..a67e283f1 100644 --- a/pkg/tcpip/ports/ports_test.go +++ b/pkg/tcpip/ports/ports_test.go @@ -34,6 +34,7 @@ type portReserveTestAction struct { want *tcpip.Error reuse bool release bool + device tcpip.NICID } func TestPortReservation(t *testing.T) { @@ -100,6 +101,112 @@ func TestPortReservation(t *testing.T) { {port: 24, ip: anyIPAddress, release: true}, {port: 24, ip: anyIPAddress, reuse: false, want: nil}, }, + }, { + tname: "bind twice with device fails", + actions: []portReserveTestAction{ + {port: 24, ip: fakeIPAddress, device: 3, want: nil}, + {port: 24, ip: fakeIPAddress, device: 3, want: tcpip.ErrPortInUse}, + }, + }, { + tname: "bind to device", + actions: []portReserveTestAction{ + {port: 24, ip: fakeIPAddress, device: 1, want: nil}, + {port: 24, ip: fakeIPAddress, device: 2, want: nil}, + }, + }, { + tname: "bind to device and then without device", + actions: []portReserveTestAction{ + {port: 24, ip: fakeIPAddress, device: 123, want: nil}, + {port: 24, ip: fakeIPAddress, device: 0, want: tcpip.ErrPortInUse}, + }, + }, { + tname: "bind without device", + actions: []portReserveTestAction{ + {port: 24, ip: fakeIPAddress, want: nil}, + {port: 24, ip: fakeIPAddress, device: 123, want: tcpip.ErrPortInUse}, + {port: 24, ip: fakeIPAddress, device: 123, reuse: true, want: tcpip.ErrPortInUse}, + {port: 24, ip: fakeIPAddress, want: tcpip.ErrPortInUse}, + {port: 24, ip: fakeIPAddress, reuse: true, want: tcpip.ErrPortInUse}, + }, + }, { + tname: "bind with device", + actions: []portReserveTestAction{ + {port: 24, ip: fakeIPAddress, device: 123, want: nil}, + {port: 24, ip: fakeIPAddress, device: 123, want: tcpip.ErrPortInUse}, + {port: 24, ip: fakeIPAddress, device: 123, reuse: true, want: tcpip.ErrPortInUse}, + {port: 24, ip: fakeIPAddress, device: 0, want: tcpip.ErrPortInUse}, + {port: 24, ip: fakeIPAddress, device: 0, reuse: true, want: tcpip.ErrPortInUse}, + {port: 24, ip: fakeIPAddress, device: 456, reuse: true, want: nil}, + {port: 24, ip: fakeIPAddress, device: 789, want: nil}, + {port: 24, ip: fakeIPAddress, want: tcpip.ErrPortInUse}, + {port: 24, ip: fakeIPAddress, reuse: true, want: tcpip.ErrPortInUse}, + }, + }, { + tname: "bind with reuse", + actions: []portReserveTestAction{ + {port: 24, ip: fakeIPAddress, reuse: true, want: nil}, + {port: 24, ip: fakeIPAddress, device: 123, want: tcpip.ErrPortInUse}, + {port: 24, ip: fakeIPAddress, device: 123, reuse: true, want: nil}, + {port: 24, ip: fakeIPAddress, device: 0, want: tcpip.ErrPortInUse}, + {port: 24, ip: fakeIPAddress, device: 0, reuse: true, want: nil}, + }, + }, { + tname: "binding with reuse and device", + actions: []portReserveTestAction{ + {port: 24, ip: fakeIPAddress, device: 123, reuse: true, want: nil}, + {port: 24, ip: fakeIPAddress, device: 123, want: tcpip.ErrPortInUse}, + {port: 24, ip: fakeIPAddress, device: 123, reuse: true, want: nil}, + {port: 24, ip: fakeIPAddress, device: 0, want: tcpip.ErrPortInUse}, + {port: 24, ip: fakeIPAddress, device: 456, reuse: true, want: nil}, + {port: 24, ip: fakeIPAddress, device: 0, reuse: true, want: nil}, + {port: 24, ip: fakeIPAddress, device: 789, reuse: true, want: nil}, + {port: 24, ip: fakeIPAddress, device: 999, want: tcpip.ErrPortInUse}, + }, + }, { + tname: "mixing reuse and not reuse by binding to device", + actions: []portReserveTestAction{ + {port: 24, ip: fakeIPAddress, device: 123, reuse: true, want: nil}, + {port: 24, ip: fakeIPAddress, device: 456, want: nil}, + {port: 24, ip: fakeIPAddress, device: 789, reuse: true, want: nil}, + {port: 24, ip: fakeIPAddress, device: 999, want: nil}, + }, + }, { + tname: "can't bind to 0 after mixing reuse and not reuse", + actions: []portReserveTestAction{ + {port: 24, ip: fakeIPAddress, device: 123, reuse: true, want: nil}, + {port: 24, ip: fakeIPAddress, device: 456, want: nil}, + {port: 24, ip: fakeIPAddress, device: 0, reuse: true, want: tcpip.ErrPortInUse}, + }, + }, { + tname: "bind and release", + actions: []portReserveTestAction{ + {port: 24, ip: fakeIPAddress, device: 123, reuse: true, want: nil}, + {port: 24, ip: fakeIPAddress, device: 0, reuse: true, want: nil}, + {port: 24, ip: fakeIPAddress, device: 345, reuse: false, want: tcpip.ErrPortInUse}, + {port: 24, ip: fakeIPAddress, device: 789, reuse: true, want: nil}, + + // Release the bind to device 0 and try again. + {port: 24, ip: fakeIPAddress, device: 0, reuse: true, want: nil, release: true}, + {port: 24, ip: fakeIPAddress, device: 345, reuse: false, want: nil}, + }, + }, { + tname: "bind twice with reuse once", + actions: []portReserveTestAction{ + {port: 24, ip: fakeIPAddress, device: 123, reuse: false, want: nil}, + {port: 24, ip: fakeIPAddress, device: 0, reuse: true, want: tcpip.ErrPortInUse}, + }, + }, { + tname: "release an unreserved device", + actions: []portReserveTestAction{ + {port: 24, ip: fakeIPAddress, device: 123, reuse: false, want: nil}, + {port: 24, ip: fakeIPAddress, device: 456, reuse: false, want: nil}, + // The below don't exist. + {port: 24, ip: fakeIPAddress, device: 345, reuse: false, want: nil, release: true}, + {port: 9999, ip: fakeIPAddress, device: 123, reuse: false, want: nil, release: true}, + // Release all. + {port: 24, ip: fakeIPAddress, device: 123, reuse: false, want: nil, release: true}, + {port: 24, ip: fakeIPAddress, device: 456, reuse: false, want: nil, release: true}, + }, }, } { t.Run(test.tname, func(t *testing.T) { @@ -108,12 +215,12 @@ func TestPortReservation(t *testing.T) { for _, test := range test.actions { if test.release { - pm.ReleasePort(net, fakeTransNumber, test.ip, test.port) + pm.ReleasePort(net, fakeTransNumber, test.ip, test.port, test.device) continue } - gotPort, err := pm.ReservePort(net, fakeTransNumber, test.ip, test.port, test.reuse) + gotPort, err := pm.ReservePort(net, fakeTransNumber, test.ip, test.port, test.reuse, test.device) if err != test.want { - t.Fatalf("ReservePort(.., .., %s, %d, %t) = %v, want %v", test.ip, test.port, test.release, err, test.want) + t.Fatalf("ReservePort(.., .., %s, %d, %t, %d) = %v, want %v", test.ip, test.port, test.reuse, test.device, err, test.want) } if test.port == 0 && (gotPort == 0 || gotPort < FirstEphemeral) { t.Fatalf("ReservePort(.., .., .., 0) = %d, want port number >= %d to be picked", gotPort, FirstEphemeral) diff --git a/pkg/tcpip/stack/BUILD b/pkg/tcpip/stack/BUILD index 28c49e8ff..3842f1f7d 100644 --- a/pkg/tcpip/stack/BUILD +++ b/pkg/tcpip/stack/BUILD @@ -54,6 +54,7 @@ go_test( size = "small", srcs = [ "stack_test.go", + "transport_demuxer_test.go", "transport_test.go", ], deps = [ @@ -64,6 +65,9 @@ go_test( "//pkg/tcpip/iptables", "//pkg/tcpip/link/channel", "//pkg/tcpip/link/loopback", + "//pkg/tcpip/network/ipv4", + "//pkg/tcpip/network/ipv6", + "//pkg/tcpip/transport/udp", "//pkg/waiter", ], ) diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index 0e8a23f00..f6106f762 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -34,8 +34,6 @@ type NIC struct { linkEP LinkEndpoint loopback bool - demux *transportDemuxer - mu sync.RWMutex spoofing bool promiscuous bool @@ -85,7 +83,6 @@ func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, loopback name: name, linkEP: ep, loopback: loopback, - demux: newTransportDemuxer(stack), primary: make(map[tcpip.NetworkProtocolNumber]*ilist.List), endpoints: make(map[NetworkEndpointID]*referencedNetworkEndpoint), mcastJoins: make(map[NetworkEndpointID]int32), @@ -707,9 +704,7 @@ func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolN // Raw socket packets are delivered based solely on the transport // protocol number. We do not inspect the payload to ensure it's // validly formed. - if !n.demux.deliverRawPacket(r, protocol, netHeader, vv) { - n.stack.demux.deliverRawPacket(r, protocol, netHeader, vv) - } + n.stack.demux.deliverRawPacket(r, protocol, netHeader, vv) if len(vv.First()) < transProto.MinimumPacketSize() { n.stack.stats.MalformedRcvdPackets.Increment() @@ -723,9 +718,6 @@ func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolN } id := TransportEndpointID{dstPort, r.LocalAddress, srcPort, r.RemoteAddress} - if n.demux.deliverPacket(r, protocol, netHeader, vv, id) { - return - } if n.stack.demux.deliverPacket(r, protocol, netHeader, vv, id) { return } @@ -767,10 +759,7 @@ func (n *NIC) DeliverTransportControlPacket(local, remote tcpip.Address, net tcp } id := TransportEndpointID{srcPort, local, dstPort, remote} - if n.demux.deliverControlPacket(net, trans, typ, extra, vv, id) { - return - } - if n.stack.demux.deliverControlPacket(net, trans, typ, extra, vv, id) { + if n.stack.demux.deliverControlPacket(n, net, trans, typ, extra, vv, id) { return } } diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index 18d1704a5..6a8079823 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -1033,73 +1033,27 @@ func (s *Stack) RemoveWaker(nicid tcpip.NICID, addr tcpip.Address, waker *sleep. // transport dispatcher. Received packets that match the provided id will be // delivered to the given endpoint; specifying a nic is optional, but // nic-specific IDs have precedence over global ones. -func (s *Stack) RegisterTransportEndpoint(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, reusePort bool) *tcpip.Error { - if nicID == 0 { - return s.demux.registerEndpoint(netProtos, protocol, id, ep, reusePort) - } - - s.mu.RLock() - defer s.mu.RUnlock() - - nic := s.nics[nicID] - if nic == nil { - return tcpip.ErrUnknownNICID - } - - return nic.demux.registerEndpoint(netProtos, protocol, id, ep, reusePort) +func (s *Stack) RegisterTransportEndpoint(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, reusePort bool, bindToDevice tcpip.NICID) *tcpip.Error { + return s.demux.registerEndpoint(netProtos, protocol, id, ep, reusePort, bindToDevice) } // UnregisterTransportEndpoint removes the endpoint with the given id from the // stack transport dispatcher. -func (s *Stack) UnregisterTransportEndpoint(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint) { - if nicID == 0 { - s.demux.unregisterEndpoint(netProtos, protocol, id, ep) - return - } - - s.mu.RLock() - defer s.mu.RUnlock() - - nic := s.nics[nicID] - if nic != nil { - nic.demux.unregisterEndpoint(netProtos, protocol, id, ep) - } +func (s *Stack) UnregisterTransportEndpoint(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, bindToDevice tcpip.NICID) { + s.demux.unregisterEndpoint(netProtos, protocol, id, ep, bindToDevice) } // RegisterRawTransportEndpoint registers the given endpoint with the stack // transport dispatcher. Received packets that match the provided transport // protocol will be delivered to the given endpoint. func (s *Stack) RegisterRawTransportEndpoint(nicID tcpip.NICID, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, ep RawTransportEndpoint) *tcpip.Error { - if nicID == 0 { - return s.demux.registerRawEndpoint(netProto, transProto, ep) - } - - s.mu.RLock() - defer s.mu.RUnlock() - - nic := s.nics[nicID] - if nic == nil { - return tcpip.ErrUnknownNICID - } - - return nic.demux.registerRawEndpoint(netProto, transProto, ep) + return s.demux.registerRawEndpoint(netProto, transProto, ep) } // UnregisterRawTransportEndpoint removes the endpoint for the transport // protocol from the stack transport dispatcher. func (s *Stack) UnregisterRawTransportEndpoint(nicID tcpip.NICID, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, ep RawTransportEndpoint) { - if nicID == 0 { - s.demux.unregisterRawEndpoint(netProto, transProto, ep) - return - } - - s.mu.RLock() - defer s.mu.RUnlock() - - nic := s.nics[nicID] - if nic != nil { - nic.demux.unregisterRawEndpoint(netProto, transProto, ep) - } + s.demux.unregisterRawEndpoint(netProto, transProto, ep) } // RegisterRestoredEndpoint records e as an endpoint that has been restored on diff --git a/pkg/tcpip/stack/transport_demuxer.go b/pkg/tcpip/stack/transport_demuxer.go index cf8a6d129..8c768c299 100644 --- a/pkg/tcpip/stack/transport_demuxer.go +++ b/pkg/tcpip/stack/transport_demuxer.go @@ -35,25 +35,109 @@ type protocolIDs struct { type transportEndpoints struct { // mu protects all fields of the transportEndpoints. mu sync.RWMutex - endpoints map[TransportEndpointID]TransportEndpoint + endpoints map[TransportEndpointID]*endpointsByNic // rawEndpoints contains endpoints for raw sockets, which receive all // traffic of a given protocol regardless of port. rawEndpoints []RawTransportEndpoint } +type endpointsByNic struct { + mu sync.RWMutex + endpoints map[tcpip.NICID]*multiPortEndpoint + // seed is a random secret for a jenkins hash. + seed uint32 +} + +// HandlePacket is called by the stack when new packets arrive to this transport +// endpoint. +func (epsByNic *endpointsByNic) handlePacket(r *Route, id TransportEndpointID, vv buffer.VectorisedView) { + epsByNic.mu.RLock() + + mpep, ok := epsByNic.endpoints[r.ref.nic.ID()] + if !ok { + if mpep, ok = epsByNic.endpoints[0]; !ok { + epsByNic.mu.RUnlock() // Don't use defer for performance reasons. + return + } + } + + // If this is a broadcast or multicast datagram, deliver the datagram to all + // endpoints bound to the right device. + if id.LocalAddress == header.IPv4Broadcast || header.IsV4MulticastAddress(id.LocalAddress) || header.IsV6MulticastAddress(id.LocalAddress) { + mpep.handlePacketAll(r, id, vv) + epsByNic.mu.RUnlock() // Don't use defer for performance reasons. + return + } + + // multiPortEndpoints are guaranteed to have at least one element. + selectEndpoint(id, mpep, epsByNic.seed).HandlePacket(r, id, vv) + epsByNic.mu.RUnlock() // Don't use defer for performance reasons. +} + +// HandleControlPacket implements stack.TransportEndpoint.HandleControlPacket. +func (epsByNic *endpointsByNic) handleControlPacket(n *NIC, id TransportEndpointID, typ ControlType, extra uint32, vv buffer.VectorisedView) { + epsByNic.mu.RLock() + defer epsByNic.mu.RUnlock() + + mpep, ok := epsByNic.endpoints[n.ID()] + if !ok { + mpep, ok = epsByNic.endpoints[0] + } + if !ok { + return + } + + // TODO(eyalsoha): Why don't we look at id to see if this packet needs to + // broadcast like we are doing with handlePacket above? + + // multiPortEndpoints are guaranteed to have at least one element. + selectEndpoint(id, mpep, epsByNic.seed).HandleControlPacket(id, typ, extra, vv) +} + +// registerEndpoint returns true if it succeeds. It fails and returns +// false if ep already has an element with the same key. +func (epsByNic *endpointsByNic) registerEndpoint(t TransportEndpoint, reusePort bool, bindToDevice tcpip.NICID) *tcpip.Error { + epsByNic.mu.Lock() + defer epsByNic.mu.Unlock() + + if multiPortEp, ok := epsByNic.endpoints[bindToDevice]; ok { + // There was already a bind. + return multiPortEp.singleRegisterEndpoint(t, reusePort) + } + + // This is a new binding. + multiPortEp := &multiPortEndpoint{} + multiPortEp.endpointsMap = make(map[TransportEndpoint]int) + multiPortEp.reuse = reusePort + epsByNic.endpoints[bindToDevice] = multiPortEp + return multiPortEp.singleRegisterEndpoint(t, reusePort) +} + +// unregisterEndpoint returns true if endpointsByNic has to be unregistered. +func (epsByNic *endpointsByNic) unregisterEndpoint(bindToDevice tcpip.NICID, t TransportEndpoint) bool { + epsByNic.mu.Lock() + defer epsByNic.mu.Unlock() + multiPortEp, ok := epsByNic.endpoints[bindToDevice] + if !ok { + return false + } + if multiPortEp.unregisterEndpoint(t) { + delete(epsByNic.endpoints, bindToDevice) + } + return len(epsByNic.endpoints) == 0 +} + // unregisterEndpoint unregisters the endpoint with the given id such that it // won't receive any more packets. -func (eps *transportEndpoints) unregisterEndpoint(id TransportEndpointID, ep TransportEndpoint) { +func (eps *transportEndpoints) unregisterEndpoint(id TransportEndpointID, ep TransportEndpoint, bindToDevice tcpip.NICID) { eps.mu.Lock() defer eps.mu.Unlock() - e, ok := eps.endpoints[id] + epsByNic, ok := eps.endpoints[id] if !ok { return } - if multiPortEp, ok := e.(*multiPortEndpoint); ok { - if !multiPortEp.unregisterEndpoint(ep) { - return - } + if !epsByNic.unregisterEndpoint(bindToDevice, ep) { + return } delete(eps.endpoints, id) } @@ -75,7 +159,7 @@ func newTransportDemuxer(stack *Stack) *transportDemuxer { for netProto := range stack.networkProtocols { for proto := range stack.transportProtocols { d.protocol[protocolIDs{netProto, proto}] = &transportEndpoints{ - endpoints: make(map[TransportEndpointID]TransportEndpoint), + endpoints: make(map[TransportEndpointID]*endpointsByNic), } } } @@ -85,10 +169,10 @@ func newTransportDemuxer(stack *Stack) *transportDemuxer { // registerEndpoint registers the given endpoint with the dispatcher such that // packets that match the endpoint ID are delivered to it. -func (d *transportDemuxer) registerEndpoint(netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, reusePort bool) *tcpip.Error { +func (d *transportDemuxer) registerEndpoint(netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, reusePort bool, bindToDevice tcpip.NICID) *tcpip.Error { for i, n := range netProtos { - if err := d.singleRegisterEndpoint(n, protocol, id, ep, reusePort); err != nil { - d.unregisterEndpoint(netProtos[:i], protocol, id, ep) + if err := d.singleRegisterEndpoint(n, protocol, id, ep, reusePort, bindToDevice); err != nil { + d.unregisterEndpoint(netProtos[:i], protocol, id, ep, bindToDevice) return err } } @@ -97,13 +181,14 @@ func (d *transportDemuxer) registerEndpoint(netProtos []tcpip.NetworkProtocolNum } // multiPortEndpoint is a container for TransportEndpoints which are bound to -// the same pair of address and port. +// the same pair of address and port. endpointsArr always has at least one +// element. type multiPortEndpoint struct { mu sync.RWMutex endpointsArr []TransportEndpoint endpointsMap map[TransportEndpoint]int - // seed is a random secret for a jenkins hash. - seed uint32 + // reuse indicates if more than one endpoint is allowed. + reuse bool } // reciprocalScale scales a value into range [0, n). @@ -117,9 +202,10 @@ func reciprocalScale(val, n uint32) uint32 { // selectEndpoint calculates a hash of destination and source addresses and // ports then uses it to select a socket. In this case, all packets from one // address will be sent to same endpoint. -func (ep *multiPortEndpoint) selectEndpoint(id TransportEndpointID) TransportEndpoint { - ep.mu.RLock() - defer ep.mu.RUnlock() +func selectEndpoint(id TransportEndpointID, mpep *multiPortEndpoint, seed uint32) TransportEndpoint { + if len(mpep.endpointsArr) == 1 { + return mpep.endpointsArr[0] + } payload := []byte{ byte(id.LocalPort), @@ -128,51 +214,50 @@ func (ep *multiPortEndpoint) selectEndpoint(id TransportEndpointID) TransportEnd byte(id.RemotePort >> 8), } - h := jenkins.Sum32(ep.seed) + h := jenkins.Sum32(seed) h.Write(payload) h.Write([]byte(id.LocalAddress)) h.Write([]byte(id.RemoteAddress)) hash := h.Sum32() - idx := reciprocalScale(hash, uint32(len(ep.endpointsArr))) - return ep.endpointsArr[idx] + idx := reciprocalScale(hash, uint32(len(mpep.endpointsArr))) + return mpep.endpointsArr[idx] } -// HandlePacket is called by the stack when new packets arrive to this transport -// endpoint. -func (ep *multiPortEndpoint) HandlePacket(r *Route, id TransportEndpointID, vv buffer.VectorisedView) { - // If this is a broadcast or multicast datagram, deliver the datagram to all - // endpoints managed by ep. - if id.LocalAddress == header.IPv4Broadcast || header.IsV4MulticastAddress(id.LocalAddress) || header.IsV6MulticastAddress(id.LocalAddress) { - for i, endpoint := range ep.endpointsArr { - // HandlePacket modifies vv, so each endpoint needs its own copy. - if i == len(ep.endpointsArr)-1 { - endpoint.HandlePacket(r, id, vv) - break - } - vvCopy := buffer.NewView(vv.Size()) - copy(vvCopy, vv.ToView()) - endpoint.HandlePacket(r, id, vvCopy.ToVectorisedView()) +func (ep *multiPortEndpoint) handlePacketAll(r *Route, id TransportEndpointID, vv buffer.VectorisedView) { + ep.mu.RLock() + for i, endpoint := range ep.endpointsArr { + // HandlePacket modifies vv, so each endpoint needs its own copy except for + // the final one. + if i == len(ep.endpointsArr)-1 { + endpoint.HandlePacket(r, id, vv) + break } - } else { - ep.selectEndpoint(id).HandlePacket(r, id, vv) + vvCopy := buffer.NewView(vv.Size()) + copy(vvCopy, vv.ToView()) + endpoint.HandlePacket(r, id, vvCopy.ToVectorisedView()) } + ep.mu.RUnlock() // Don't use defer for performance reasons. } -// HandleControlPacket implements stack.TransportEndpoint.HandleControlPacket. -func (ep *multiPortEndpoint) HandleControlPacket(id TransportEndpointID, typ ControlType, extra uint32, vv buffer.VectorisedView) { - ep.selectEndpoint(id).HandleControlPacket(id, typ, extra, vv) -} - -func (ep *multiPortEndpoint) singleRegisterEndpoint(t TransportEndpoint) { +// singleRegisterEndpoint tries to add an endpoint to the multiPortEndpoint +// list. The list might be empty already. +func (ep *multiPortEndpoint) singleRegisterEndpoint(t TransportEndpoint, reusePort bool) *tcpip.Error { ep.mu.Lock() defer ep.mu.Unlock() - // A new endpoint is added into endpointsArr and its index there is - // saved in endpointsMap. This will allows to remove endpoint from - // the array fast. + if len(ep.endpointsArr) > 0 { + // If it was previously bound, we need to check if we can bind again. + if !ep.reuse || !reusePort { + return tcpip.ErrPortInUse + } + } + + // A new endpoint is added into endpointsArr and its index there is saved in + // endpointsMap. This will allow us to remove endpoint from the array fast. ep.endpointsMap[t] = len(ep.endpointsArr) ep.endpointsArr = append(ep.endpointsArr, t) + return nil } // unregisterEndpoint returns true if multiPortEndpoint has to be unregistered. @@ -197,53 +282,41 @@ func (ep *multiPortEndpoint) unregisterEndpoint(t TransportEndpoint) bool { return true } -func (d *transportDemuxer) singleRegisterEndpoint(netProto tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, reusePort bool) *tcpip.Error { +func (d *transportDemuxer) singleRegisterEndpoint(netProto tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, reusePort bool, bindToDevice tcpip.NICID) *tcpip.Error { if id.RemotePort != 0 { + // TODO(eyalsoha): Why? reusePort = false } eps, ok := d.protocol[protocolIDs{netProto, protocol}] if !ok { - return nil + return tcpip.ErrUnknownProtocol } eps.mu.Lock() defer eps.mu.Unlock() - var multiPortEp *multiPortEndpoint - if _, ok := eps.endpoints[id]; ok { - if !reusePort { - return tcpip.ErrPortInUse - } - multiPortEp, ok = eps.endpoints[id].(*multiPortEndpoint) - if !ok { - return tcpip.ErrPortInUse - } + if epsByNic, ok := eps.endpoints[id]; ok { + // There was already a binding. + return epsByNic.registerEndpoint(ep, reusePort, bindToDevice) } - if reusePort { - if multiPortEp == nil { - multiPortEp = &multiPortEndpoint{} - multiPortEp.endpointsMap = make(map[TransportEndpoint]int) - multiPortEp.seed = rand.Uint32() - eps.endpoints[id] = multiPortEp - } - - multiPortEp.singleRegisterEndpoint(ep) - - return nil + // This is a new binding. + epsByNic := &endpointsByNic{ + endpoints: make(map[tcpip.NICID]*multiPortEndpoint), + seed: rand.Uint32(), } - eps.endpoints[id] = ep + eps.endpoints[id] = epsByNic - return nil + return epsByNic.registerEndpoint(ep, reusePort, bindToDevice) } // unregisterEndpoint unregisters the endpoint with the given id such that it // won't receive any more packets. -func (d *transportDemuxer) unregisterEndpoint(netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint) { +func (d *transportDemuxer) unregisterEndpoint(netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, bindToDevice tcpip.NICID) { for _, n := range netProtos { if eps, ok := d.protocol[protocolIDs{n, protocol}]; ok { - eps.unregisterEndpoint(id, ep) + eps.unregisterEndpoint(id, ep, bindToDevice) } } } @@ -273,7 +346,7 @@ func (d *transportDemuxer) deliverPacket(r *Route, protocol tcpip.TransportProto // If the packet is a broadcast, then find all matching transport endpoints. // Otherwise, try to find a single matching transport endpoint. - destEps := make([]TransportEndpoint, 0, 1) + destEps := make([]*endpointsByNic, 0, 1) eps.mu.RLock() if protocol == header.UDPProtocolNumber && id.LocalAddress == header.IPv4Broadcast { @@ -299,7 +372,7 @@ func (d *transportDemuxer) deliverPacket(r *Route, protocol tcpip.TransportProto // Deliver the packet. for _, ep := range destEps { - ep.HandlePacket(r, id, vv) + ep.handlePacket(r, id, vv) } return true @@ -331,7 +404,7 @@ func (d *transportDemuxer) deliverRawPacket(r *Route, protocol tcpip.TransportPr // deliverControlPacket attempts to deliver the given control packet. Returns // true if it found an endpoint, false otherwise. -func (d *transportDemuxer) deliverControlPacket(net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ ControlType, extra uint32, vv buffer.VectorisedView, id TransportEndpointID) bool { +func (d *transportDemuxer) deliverControlPacket(n *NIC, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ ControlType, extra uint32, vv buffer.VectorisedView, id TransportEndpointID) bool { eps, ok := d.protocol[protocolIDs{net, trans}] if !ok { return false @@ -348,12 +421,12 @@ func (d *transportDemuxer) deliverControlPacket(net tcpip.NetworkProtocolNumber, } // Deliver the packet. - ep.HandleControlPacket(id, typ, extra, vv) + ep.handleControlPacket(n, id, typ, extra, vv) return true } -func (d *transportDemuxer) findEndpointLocked(eps *transportEndpoints, vv buffer.VectorisedView, id TransportEndpointID) TransportEndpoint { +func (d *transportDemuxer) findEndpointLocked(eps *transportEndpoints, vv buffer.VectorisedView, id TransportEndpointID) *endpointsByNic { // Try to find a match with the id as provided. if ep, ok := eps.endpoints[id]; ok { return ep diff --git a/pkg/tcpip/stack/transport_demuxer_test.go b/pkg/tcpip/stack/transport_demuxer_test.go new file mode 100644 index 000000000..210233dc0 --- /dev/null +++ b/pkg/tcpip/stack/transport_demuxer_test.go @@ -0,0 +1,352 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package stack_test + +import ( + "math" + "math/rand" + "testing" + + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/link/channel" + "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" + "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" + "gvisor.dev/gvisor/pkg/tcpip/stack" + "gvisor.dev/gvisor/pkg/tcpip/transport/udp" + "gvisor.dev/gvisor/pkg/waiter" +) + +const ( + stackV6Addr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01" + testV6Addr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02" + + stackAddr = "\x0a\x00\x00\x01" + stackPort = 1234 + testPort = 4096 +) + +type testContext struct { + t *testing.T + linkEPs map[string]*channel.Endpoint + s *stack.Stack + + ep tcpip.Endpoint + wq waiter.Queue +} + +func (c *testContext) cleanup() { + if c.ep != nil { + c.ep.Close() + } +} + +func (c *testContext) createV6Endpoint(v6only bool) { + var err *tcpip.Error + c.ep, err = c.s.NewEndpoint(udp.ProtocolNumber, ipv6.ProtocolNumber, &c.wq) + if err != nil { + c.t.Fatalf("NewEndpoint failed: %v", err) + } + + var v tcpip.V6OnlyOption + if v6only { + v = 1 + } + if err := c.ep.SetSockOpt(v); err != nil { + c.t.Fatalf("SetSockOpt failed: %v", err) + } +} + +// newDualTestContextMultiNic creates the testing context and also linkEpNames +// named NICs. +func newDualTestContextMultiNic(t *testing.T, mtu uint32, linkEpNames []string) *testContext { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol()}, + TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()}}) + linkEPs := make(map[string]*channel.Endpoint) + for i, linkEpName := range linkEpNames { + channelEP := channel.New(256, mtu, "") + nicid := tcpip.NICID(i + 1) + if err := s.CreateNamedNIC(nicid, linkEpName, channelEP); err != nil { + t.Fatalf("CreateNIC failed: %v", err) + } + linkEPs[linkEpName] = channelEP + + if err := s.AddAddress(nicid, ipv4.ProtocolNumber, stackAddr); err != nil { + t.Fatalf("AddAddress IPv4 failed: %v", err) + } + + if err := s.AddAddress(nicid, ipv6.ProtocolNumber, stackV6Addr); err != nil { + t.Fatalf("AddAddress IPv6 failed: %v", err) + } + } + + s.SetRouteTable([]tcpip.Route{ + { + Destination: header.IPv4EmptySubnet, + NIC: 1, + }, + { + Destination: header.IPv6EmptySubnet, + NIC: 1, + }, + }) + + return &testContext{ + t: t, + s: s, + linkEPs: linkEPs, + } +} + +type headers struct { + srcPort uint16 + dstPort uint16 +} + +func newPayload() []byte { + b := make([]byte, 30+rand.Intn(100)) + for i := range b { + b[i] = byte(rand.Intn(256)) + } + return b +} + +func (c *testContext) sendV6Packet(payload []byte, h *headers, linkEpName string) { + // Allocate a buffer for data and headers. + buf := buffer.NewView(header.UDPMinimumSize + header.IPv6MinimumSize + len(payload)) + copy(buf[len(buf)-len(payload):], payload) + + // Initialize the IP header. + ip := header.IPv6(buf) + ip.Encode(&header.IPv6Fields{ + PayloadLength: uint16(header.UDPMinimumSize + len(payload)), + NextHeader: uint8(udp.ProtocolNumber), + HopLimit: 65, + SrcAddr: testV6Addr, + DstAddr: stackV6Addr, + }) + + // Initialize the UDP header. + u := header.UDP(buf[header.IPv6MinimumSize:]) + u.Encode(&header.UDPFields{ + SrcPort: h.srcPort, + DstPort: h.dstPort, + Length: uint16(header.UDPMinimumSize + len(payload)), + }) + + // Calculate the UDP pseudo-header checksum. + xsum := header.PseudoHeaderChecksum(udp.ProtocolNumber, testV6Addr, stackV6Addr, uint16(len(u))) + + // Calculate the UDP checksum and set it. + xsum = header.Checksum(payload, xsum) + u.SetChecksum(^u.CalculateChecksum(xsum)) + + // Inject packet. + c.linkEPs[linkEpName].Inject(ipv6.ProtocolNumber, buf.ToVectorisedView()) +} + +func TestTransportDemuxerRegister(t *testing.T) { + for _, test := range []struct { + name string + proto tcpip.NetworkProtocolNumber + want *tcpip.Error + }{ + {"failure", ipv6.ProtocolNumber, tcpip.ErrUnknownProtocol}, + {"success", ipv4.ProtocolNumber, nil}, + } { + t.Run(test.name, func(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()}, + TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()}}) + if got, want := s.RegisterTransportEndpoint(0, []tcpip.NetworkProtocolNumber{test.proto}, udp.ProtocolNumber, stack.TransportEndpointID{}, nil, false, 0), test.want; got != want { + t.Fatalf("s.RegisterTransportEndpoint(...) = %v, want %v", got, want) + } + }) + } +} + +// TestReuseBindToDevice injects varied packets on input devices and checks that +// the distribution of packets received matches expectations. +func TestDistribution(t *testing.T) { + type endpointSockopts struct { + reuse int + bindToDevice string + } + for _, test := range []struct { + name string + // endpoints will received the inject packets. + endpoints []endpointSockopts + // wantedDistribution is the wanted ratio of packets received on each + // endpoint for each NIC on which packets are injected. + wantedDistributions map[string][]float64 + }{ + { + "BindPortReuse", + // 5 endpoints that all have reuse set. + []endpointSockopts{ + endpointSockopts{1, ""}, + endpointSockopts{1, ""}, + endpointSockopts{1, ""}, + endpointSockopts{1, ""}, + endpointSockopts{1, ""}, + }, + map[string][]float64{ + // Injected packets on dev0 get distributed evenly. + "dev0": []float64{0.2, 0.2, 0.2, 0.2, 0.2}, + }, + }, + { + "BindToDevice", + // 3 endpoints with various bindings. + []endpointSockopts{ + endpointSockopts{0, "dev0"}, + endpointSockopts{0, "dev1"}, + endpointSockopts{0, "dev2"}, + }, + map[string][]float64{ + // Injected packets on dev0 go only to the endpoint bound to dev0. + "dev0": []float64{1, 0, 0}, + // Injected packets on dev1 go only to the endpoint bound to dev1. + "dev1": []float64{0, 1, 0}, + // Injected packets on dev2 go only to the endpoint bound to dev2. + "dev2": []float64{0, 0, 1}, + }, + }, + { + "ReuseAndBindToDevice", + // 6 endpoints with various bindings. + []endpointSockopts{ + endpointSockopts{1, "dev0"}, + endpointSockopts{1, "dev0"}, + endpointSockopts{1, "dev1"}, + endpointSockopts{1, "dev1"}, + endpointSockopts{1, "dev1"}, + endpointSockopts{1, ""}, + }, + map[string][]float64{ + // Injected packets on dev0 get distributed among endpoints bound to + // dev0. + "dev0": []float64{0.5, 0.5, 0, 0, 0, 0}, + // Injected packets on dev1 get distributed among endpoints bound to + // dev1 or unbound. + "dev1": []float64{0, 0, 1. / 3, 1. / 3, 1. / 3, 0}, + // Injected packets on dev999 go only to the unbound. + "dev999": []float64{0, 0, 0, 0, 0, 1}, + }, + }, + } { + t.Run(test.name, func(t *testing.T) { + for device, wantedDistribution := range test.wantedDistributions { + t.Run(device, func(t *testing.T) { + var devices []string + for d := range test.wantedDistributions { + devices = append(devices, d) + } + c := newDualTestContextMultiNic(t, defaultMTU, devices) + defer c.cleanup() + + c.createV6Endpoint(false) + + eps := make(map[tcpip.Endpoint]int) + + pollChannel := make(chan tcpip.Endpoint) + for i, endpoint := range test.endpoints { + // Try to receive the data. + wq := waiter.Queue{} + we, ch := waiter.NewChannelEntry(nil) + wq.EventRegister(&we, waiter.EventIn) + defer wq.EventUnregister(&we) + defer close(ch) + + var err *tcpip.Error + ep, err := c.s.NewEndpoint(udp.ProtocolNumber, ipv6.ProtocolNumber, &wq) + if err != nil { + c.t.Fatalf("NewEndpoint failed: %v", err) + } + eps[ep] = i + + go func(ep tcpip.Endpoint) { + for range ch { + pollChannel <- ep + } + }(ep) + + defer ep.Close() + reusePortOption := tcpip.ReusePortOption(endpoint.reuse) + if err := ep.SetSockOpt(reusePortOption); err != nil { + c.t.Fatalf("SetSockOpt(%#v) on endpoint %d failed: %v", reusePortOption, i, err) + } + bindToDeviceOption := tcpip.BindToDeviceOption(endpoint.bindToDevice) + if err := ep.SetSockOpt(bindToDeviceOption); err != nil { + c.t.Fatalf("SetSockOpt(%#v) on endpoint %d failed: %v", bindToDeviceOption, i, err) + } + if err := ep.Bind(tcpip.FullAddress{Addr: stackV6Addr, Port: stackPort}); err != nil { + t.Fatalf("ep.Bind(...) on endpoint %d failed: %v", i, err) + } + } + + npackets := 100000 + nports := 10000 + if got, want := len(test.endpoints), len(wantedDistribution); got != want { + t.Fatalf("got len(test.endpoints) = %d, want %d", got, want) + } + ports := make(map[uint16]tcpip.Endpoint) + stats := make(map[tcpip.Endpoint]int) + for i := 0; i < npackets; i++ { + // Send a packet. + port := uint16(i % nports) + payload := newPayload() + c.sendV6Packet(payload, + &headers{ + srcPort: testPort + port, + dstPort: stackPort}, + device) + + var addr tcpip.FullAddress + ep := <-pollChannel + _, _, err := ep.Read(&addr) + if err != nil { + c.t.Fatalf("Read on endpoint %d failed: %v", eps[ep], err) + } + stats[ep]++ + if i < nports { + ports[uint16(i)] = ep + } else { + // Check that all packets from one client are handled by the same + // socket. + if want, got := ports[port], ep; want != got { + t.Fatalf("Packet sent on port %d expected on endpoint %d but received on endpoint %d", port, eps[want], eps[got]) + } + } + } + + // Check that a packet distribution is as expected. + for ep, i := range eps { + wantedRatio := wantedDistribution[i] + wantedRecv := wantedRatio * float64(npackets) + actualRecv := stats[ep] + actualRatio := float64(stats[ep]) / float64(npackets) + // The deviation is less than 10%. + if math.Abs(actualRatio-wantedRatio) > 0.05 { + t.Errorf("wanted about %.0f%% (%.0f of %d) packets to arrive on endpoint %d, got %.0f%% (%d of %d)", wantedRatio*100, wantedRecv, npackets, i, actualRatio*100, actualRecv, npackets) + } + } + }) + } + }) + } +} diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go index 56e8a5d9b..842a16277 100644 --- a/pkg/tcpip/stack/transport_test.go +++ b/pkg/tcpip/stack/transport_test.go @@ -127,7 +127,7 @@ func (f *fakeTransportEndpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { // Try to register so that we can start receiving packets. f.id.RemoteAddress = addr.Addr - err = f.stack.RegisterTransportEndpoint(0, []tcpip.NetworkProtocolNumber{fakeNetNumber}, fakeTransNumber, f.id, f, false) + err = f.stack.RegisterTransportEndpoint(0, []tcpip.NetworkProtocolNumber{fakeNetNumber}, fakeTransNumber, f.id, f, false /* reuse */, 0 /* bindToDevice */) if err != nil { return err } @@ -168,7 +168,8 @@ func (f *fakeTransportEndpoint) Bind(a tcpip.FullAddress) *tcpip.Error { fakeTransNumber, stack.TransportEndpointID{LocalAddress: a.Addr}, f, - false, + false, /* reuse */ + 0, /* bindtoDevice */ ); err != nil { return err } diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go index c021c67ac..faaa4a4e3 100644 --- a/pkg/tcpip/tcpip.go +++ b/pkg/tcpip/tcpip.go @@ -495,6 +495,10 @@ type ReuseAddressOption int // to be bound to an identical socket address. type ReusePortOption int +// BindToDeviceOption is used by SetSockOpt/GetSockOpt to specify that sockets +// should bind only on a specific NIC. +type BindToDeviceOption string + // QuickAckOption is stubbed out in SetSockOpt/GetSockOpt. type QuickAckOption int diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go index a111fdb2a..a3a910d41 100644 --- a/pkg/tcpip/transport/icmp/endpoint.go +++ b/pkg/tcpip/transport/icmp/endpoint.go @@ -104,7 +104,7 @@ func (e *endpoint) Close() { e.shutdownFlags = tcpip.ShutdownRead | tcpip.ShutdownWrite switch e.state { case stateBound, stateConnected: - e.stack.UnregisterTransportEndpoint(e.regNICID, []tcpip.NetworkProtocolNumber{e.netProto}, e.transProto, e.id, e) + e.stack.UnregisterTransportEndpoint(e.regNICID, []tcpip.NetworkProtocolNumber{e.netProto}, e.transProto, e.id, e, 0 /* bindToDevice */) } // Close the receive list and drain it. @@ -543,14 +543,14 @@ func (e *endpoint) registerWithStack(nicid tcpip.NICID, netProtos []tcpip.Networ if id.LocalPort != 0 { // The endpoint already has a local port, just attempt to // register it. - err := e.stack.RegisterTransportEndpoint(nicid, netProtos, e.transProto, id, e, false) + err := e.stack.RegisterTransportEndpoint(nicid, netProtos, e.transProto, id, e, false /* reuse */, 0 /* bindToDevice */) return id, err } // We need to find a port for the endpoint. _, err := e.stack.PickEphemeralPort(func(p uint16) (bool, *tcpip.Error) { id.LocalPort = p - err := e.stack.RegisterTransportEndpoint(nicid, netProtos, e.transProto, id, e, false) + err := e.stack.RegisterTransportEndpoint(nicid, netProtos, e.transProto, id, e, false /* reuse */, 0 /* bindtodevice */) switch err { case nil: return true, nil diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go index 0802e984e..3ae4a5426 100644 --- a/pkg/tcpip/transport/tcp/accept.go +++ b/pkg/tcpip/transport/tcp/accept.go @@ -242,7 +242,7 @@ func (l *listenContext) createConnectingEndpoint(s *segment, iss seqnum.Value, i n.initGSO() // Register new endpoint so that packets are routed to it. - if err := n.stack.RegisterTransportEndpoint(n.boundNICID, n.effectiveNetProtos, ProtocolNumber, n.id, n, n.reusePort); err != nil { + if err := n.stack.RegisterTransportEndpoint(n.boundNICID, n.effectiveNetProtos, ProtocolNumber, n.id, n, n.reusePort, n.bindToDevice); err != nil { n.Close() return nil, err } diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index 35b489c68..a1cd0d481 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -280,6 +280,9 @@ type endpoint struct { // reusePort is set to true if SO_REUSEPORT is enabled. reusePort bool + // bindToDevice is set to the NIC on which to bind or disabled if 0. + bindToDevice tcpip.NICID + // delay enables Nagle's algorithm. // // delay is a boolean (0 is false) and must be accessed atomically. @@ -564,11 +567,11 @@ func (e *endpoint) Close() { // in Listen() when trying to register. if e.state == StateListen && e.isPortReserved { if e.isRegistered { - e.stack.UnregisterTransportEndpoint(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.id, e) + e.stack.UnregisterTransportEndpoint(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.id, e, e.bindToDevice) e.isRegistered = false } - e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.id.LocalAddress, e.id.LocalPort) + e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.id.LocalAddress, e.id.LocalPort, e.bindToDevice) e.isPortReserved = false } @@ -625,12 +628,12 @@ func (e *endpoint) cleanupLocked() { e.workerCleanup = false if e.isRegistered { - e.stack.UnregisterTransportEndpoint(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.id, e) + e.stack.UnregisterTransportEndpoint(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.id, e, e.bindToDevice) e.isRegistered = false } if e.isPortReserved { - e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.id.LocalAddress, e.id.LocalPort) + e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.id.LocalAddress, e.id.LocalPort, e.bindToDevice) e.isPortReserved = false } @@ -1060,6 +1063,21 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { e.mu.Unlock() return nil + case tcpip.BindToDeviceOption: + e.mu.Lock() + defer e.mu.Unlock() + if v == "" { + e.bindToDevice = 0 + return nil + } + for nicid, nic := range e.stack.NICInfo() { + if nic.Name == string(v) { + e.bindToDevice = nicid + return nil + } + } + return tcpip.ErrUnknownDevice + case tcpip.QuickAckOption: if v == 0 { atomic.StoreUint32(&e.slowAck, 1) @@ -1260,6 +1278,16 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { } return nil + case *tcpip.BindToDeviceOption: + e.mu.RLock() + defer e.mu.RUnlock() + if nic, ok := e.stack.NICInfo()[e.bindToDevice]; ok { + *o = tcpip.BindToDeviceOption(nic.Name) + return nil + } + *o = "" + return nil + case *tcpip.QuickAckOption: *o = 1 if v := atomic.LoadUint32(&e.slowAck); v != 0 { @@ -1466,7 +1494,7 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) (er if e.id.LocalPort != 0 { // The endpoint is bound to a port, attempt to register it. - err := e.stack.RegisterTransportEndpoint(nicid, netProtos, ProtocolNumber, e.id, e, e.reusePort) + err := e.stack.RegisterTransportEndpoint(nicid, netProtos, ProtocolNumber, e.id, e, e.reusePort, e.bindToDevice) if err != nil { return err } @@ -1480,13 +1508,15 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) (er if sameAddr && p == e.id.RemotePort { return false, nil } - if !e.stack.IsPortAvailable(netProtos, ProtocolNumber, e.id.LocalAddress, p, false) { + // reusePort is false below because connect cannot reuse a port even if + // reusePort was set. + if !e.stack.IsPortAvailable(netProtos, ProtocolNumber, e.id.LocalAddress, p, false /* reusePort */, e.bindToDevice) { return false, nil } id := e.id id.LocalPort = p - switch e.stack.RegisterTransportEndpoint(nicid, netProtos, ProtocolNumber, id, e, e.reusePort) { + switch e.stack.RegisterTransportEndpoint(nicid, netProtos, ProtocolNumber, id, e, e.reusePort, e.bindToDevice) { case nil: e.id = id return true, nil @@ -1504,7 +1534,7 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) (er // before Connect: in such a case we don't want to hold on to // reservations anymore. if e.isPortReserved { - e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, origID.LocalAddress, origID.LocalPort) + e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, origID.LocalAddress, origID.LocalPort, e.bindToDevice) e.isPortReserved = false } @@ -1648,7 +1678,7 @@ func (e *endpoint) Listen(backlog int) (err *tcpip.Error) { } // Register the endpoint. - if err := e.stack.RegisterTransportEndpoint(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.id, e, e.reusePort); err != nil { + if err := e.stack.RegisterTransportEndpoint(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.id, e, e.reusePort, e.bindToDevice); err != nil { return err } @@ -1729,7 +1759,7 @@ func (e *endpoint) Bind(addr tcpip.FullAddress) (err *tcpip.Error) { } } - port, err := e.stack.ReservePort(netProtos, ProtocolNumber, addr.Addr, addr.Port, e.reusePort) + port, err := e.stack.ReservePort(netProtos, ProtocolNumber, addr.Addr, addr.Port, e.reusePort, e.bindToDevice) if err != nil { return err } @@ -1739,16 +1769,16 @@ func (e *endpoint) Bind(addr tcpip.FullAddress) (err *tcpip.Error) { e.id.LocalPort = port // Any failures beyond this point must remove the port registration. - defer func() { + defer func(bindToDevice tcpip.NICID) { if err != nil { - e.stack.ReleasePort(netProtos, ProtocolNumber, addr.Addr, port) + e.stack.ReleasePort(netProtos, ProtocolNumber, addr.Addr, port, bindToDevice) e.isPortReserved = false e.effectiveNetProtos = nil e.id.LocalPort = 0 e.id.LocalAddress = "" e.boundNICID = 0 } - }() + }(e.bindToDevice) // If an address is specified, we must ensure that it's one of our // local addresses. diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go index 2be094876..089826a88 100644 --- a/pkg/tcpip/transport/tcp/tcp_test.go +++ b/pkg/tcpip/transport/tcp/tcp_test.go @@ -465,6 +465,66 @@ func TestSimpleReceive(t *testing.T) { ) } +func TestConnectBindToDevice(t *testing.T) { + for _, test := range []struct { + name string + device string + want tcp.EndpointState + }{ + {"RightDevice", "nic1", tcp.StateEstablished}, + {"WrongDevice", "nic2", tcp.StateSynSent}, + {"AnyDevice", "", tcp.StateEstablished}, + } { + t.Run(test.name, func(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + c.Create(-1) + bindToDevice := tcpip.BindToDeviceOption(test.device) + c.EP.SetSockOpt(bindToDevice) + // Start connection attempt. + waitEntry, _ := waiter.NewChannelEntry(nil) + c.WQ.EventRegister(&waitEntry, waiter.EventOut) + defer c.WQ.EventUnregister(&waitEntry) + + if err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}); err != tcpip.ErrConnectStarted { + t.Fatalf("Unexpected return value from Connect: %v", err) + } + + // Receive SYN packet. + b := c.GetPacket() + checker.IPv4(t, b, + checker.TCP( + checker.DstPort(context.TestPort), + checker.TCPFlags(header.TCPFlagSyn), + ), + ) + if got, want := tcp.EndpointState(c.EP.State()), tcp.StateSynSent; got != want { + t.Fatalf("Unexpected endpoint state: want %v, got %v", want, got) + } + tcpHdr := header.TCP(header.IPv4(b).Payload()) + c.IRS = seqnum.Value(tcpHdr.SequenceNumber()) + + iss := seqnum.Value(789) + rcvWnd := seqnum.Size(30000) + c.SendPacket(nil, &context.Headers{ + SrcPort: tcpHdr.DestinationPort(), + DstPort: tcpHdr.SourcePort(), + Flags: header.TCPFlagSyn | header.TCPFlagAck, + SeqNum: iss, + AckNum: c.IRS.Add(1), + RcvWnd: rcvWnd, + TCPOpts: nil, + }) + + c.GetPacket() + if got, want := tcp.EndpointState(c.EP.State()), test.want; got != want { + t.Fatalf("Unexpected endpoint state: want %v, got %v", want, got) + } + }) + } +} + func TestOutOfOrderReceive(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() @@ -2970,6 +3030,62 @@ func TestMinMaxBufferSizes(t *testing.T) { checkSendBufferSize(t, ep, tcp.DefaultSendBufferSize*30) } +func TestBindToDeviceOption(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()}, + TransportProtocols: []stack.TransportProtocol{tcp.NewProtocol()}}) + + ep, err := s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{}) + if err != nil { + t.Fatalf("NewEndpoint failed; %v", err) + } + defer ep.Close() + + if err := s.CreateNamedNIC(321, "my_device", loopback.New()); err != nil { + t.Errorf("CreateNamedNIC failed: %v", err) + } + + // Make an nameless NIC. + if err := s.CreateNIC(54321, loopback.New()); err != nil { + t.Errorf("CreateNIC failed: %v", err) + } + + // strPtr is used instead of taking the address of string literals, which is + // a compiler error. + strPtr := func(s string) *string { + return &s + } + + testActions := []struct { + name string + setBindToDevice *string + setBindToDeviceError *tcpip.Error + getBindToDevice tcpip.BindToDeviceOption + }{ + {"GetDefaultValue", nil, nil, ""}, + {"BindToNonExistent", strPtr("non_existent_device"), tcpip.ErrUnknownDevice, ""}, + {"BindToExistent", strPtr("my_device"), nil, "my_device"}, + {"UnbindToDevice", strPtr(""), nil, ""}, + } + for _, testAction := range testActions { + t.Run(testAction.name, func(t *testing.T) { + if testAction.setBindToDevice != nil { + bindToDevice := tcpip.BindToDeviceOption(*testAction.setBindToDevice) + if got, want := ep.SetSockOpt(bindToDevice), testAction.setBindToDeviceError; got != want { + t.Errorf("SetSockOpt(%v) got %v, want %v", bindToDevice, got, want) + } + } + bindToDevice := tcpip.BindToDeviceOption("to be modified by GetSockOpt") + if ep.GetSockOpt(&bindToDevice) != nil { + t.Errorf("GetSockOpt got %v, want %v", ep.GetSockOpt(&bindToDevice), nil) + } + if got, want := bindToDevice, testAction.getBindToDevice; got != want { + t.Errorf("bindToDevice got %q, want %q", got, want) + } + }) + } +} + func makeStack() (*stack.Stack, *tcpip.Error) { s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocol{ diff --git a/pkg/tcpip/transport/tcp/testing/context/context.go b/pkg/tcpip/transport/tcp/testing/context/context.go index d3f1d2cdf..ef823e4ae 100644 --- a/pkg/tcpip/transport/tcp/testing/context/context.go +++ b/pkg/tcpip/transport/tcp/testing/context/context.go @@ -158,7 +158,14 @@ func New(t *testing.T, mtu uint32) *Context { if testing.Verbose() { wep = sniffer.New(ep) } - if err := s.CreateNIC(1, wep); err != nil { + if err := s.CreateNamedNIC(1, "nic1", wep); err != nil { + t.Fatalf("CreateNIC failed: %v", err) + } + wep2 := stack.LinkEndpoint(channel.New(1000, mtu, "")) + if testing.Verbose() { + wep2 = sniffer.New(channel.New(1000, mtu, "")) + } + if err := s.CreateNamedNIC(2, "nic2", wep2); err != nil { t.Fatalf("CreateNIC failed: %v", err) } @@ -588,12 +595,8 @@ func (c *Context) Connect(iss seqnum.Value, rcvWnd seqnum.Size, options []byte) c.Port = tcpHdr.SourcePort() } -// CreateConnectedWithRawOptions creates a connected TCP endpoint and sends -// the specified option bytes as the Option field in the initial SYN packet. -// -// It also sets the receive buffer for the endpoint to the specified -// value in epRcvBuf. -func (c *Context) CreateConnectedWithRawOptions(iss seqnum.Value, rcvWnd seqnum.Size, epRcvBuf int, options []byte) { +// Create creates a TCP endpoint. +func (c *Context) Create(epRcvBuf int) { // Create TCP endpoint. var err *tcpip.Error c.EP, err = c.s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ) @@ -606,6 +609,15 @@ func (c *Context) CreateConnectedWithRawOptions(iss seqnum.Value, rcvWnd seqnum. c.t.Fatalf("SetSockOpt failed failed: %v", err) } } +} + +// CreateConnectedWithRawOptions creates a connected TCP endpoint and sends +// the specified option bytes as the Option field in the initial SYN packet. +// +// It also sets the receive buffer for the endpoint to the specified +// value in epRcvBuf. +func (c *Context) CreateConnectedWithRawOptions(iss seqnum.Value, rcvWnd seqnum.Size, epRcvBuf int, options []byte) { + c.Create(epRcvBuf) c.Connect(iss, rcvWnd, options) } diff --git a/pkg/tcpip/transport/udp/BUILD b/pkg/tcpip/transport/udp/BUILD index c1ca22b35..7a635ab8d 100644 --- a/pkg/tcpip/transport/udp/BUILD +++ b/pkg/tcpip/transport/udp/BUILD @@ -52,6 +52,7 @@ go_test( "//pkg/tcpip/checker", "//pkg/tcpip/header", "//pkg/tcpip/link/channel", + "//pkg/tcpip/link/loopback", "//pkg/tcpip/link/sniffer", "//pkg/tcpip/network/ipv4", "//pkg/tcpip/network/ipv6", diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index 0bec7e62d..52f5af777 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -88,6 +88,7 @@ type endpoint struct { multicastNICID tcpip.NICID multicastLoop bool reusePort bool + bindToDevice tcpip.NICID broadcast bool // shutdownFlags represent the current shutdown state of the endpoint. @@ -144,8 +145,8 @@ func (e *endpoint) Close() { switch e.state { case StateBound, StateConnected: - e.stack.UnregisterTransportEndpoint(e.regNICID, e.effectiveNetProtos, ProtocolNumber, e.id, e) - e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.id.LocalAddress, e.id.LocalPort) + e.stack.UnregisterTransportEndpoint(e.regNICID, e.effectiveNetProtos, ProtocolNumber, e.id, e, e.bindToDevice) + e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.id.LocalAddress, e.id.LocalPort, e.bindToDevice) } for _, mem := range e.multicastMemberships { @@ -551,6 +552,21 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { e.reusePort = v != 0 e.mu.Unlock() + case tcpip.BindToDeviceOption: + e.mu.Lock() + defer e.mu.Unlock() + if v == "" { + e.bindToDevice = 0 + return nil + } + for nicid, nic := range e.stack.NICInfo() { + if nic.Name == string(v) { + e.bindToDevice = nicid + return nil + } + } + return tcpip.ErrUnknownDevice + case tcpip.BroadcastOption: e.mu.Lock() e.broadcast = v != 0 @@ -646,6 +662,16 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { } return nil + case *tcpip.BindToDeviceOption: + e.mu.RLock() + defer e.mu.RUnlock() + if nic, ok := e.stack.NICInfo()[e.bindToDevice]; ok { + *o = tcpip.BindToDeviceOption(nic.Name) + return nil + } + *o = tcpip.BindToDeviceOption("") + return nil + case *tcpip.KeepaliveEnabledOption: *o = 0 return nil @@ -753,12 +779,12 @@ func (e *endpoint) Disconnect() *tcpip.Error { } else { if e.id.LocalPort != 0 { // Release the ephemeral port. - e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.id.LocalAddress, e.id.LocalPort) + e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.id.LocalAddress, e.id.LocalPort, e.bindToDevice) } e.state = StateInitial } - e.stack.UnregisterTransportEndpoint(e.regNICID, e.effectiveNetProtos, ProtocolNumber, e.id, e) + e.stack.UnregisterTransportEndpoint(e.regNICID, e.effectiveNetProtos, ProtocolNumber, e.id, e, e.bindToDevice) e.id = id e.route.Release() e.route = stack.Route{} @@ -835,7 +861,7 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { // Remove the old registration. if e.id.LocalPort != 0 { - e.stack.UnregisterTransportEndpoint(e.regNICID, e.effectiveNetProtos, ProtocolNumber, e.id, e) + e.stack.UnregisterTransportEndpoint(e.regNICID, e.effectiveNetProtos, ProtocolNumber, e.id, e, e.bindToDevice) } e.id = id @@ -898,16 +924,16 @@ func (*endpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) { func (e *endpoint) registerWithStack(nicid tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, id stack.TransportEndpointID) (stack.TransportEndpointID, *tcpip.Error) { if e.id.LocalPort == 0 { - port, err := e.stack.ReservePort(netProtos, ProtocolNumber, id.LocalAddress, id.LocalPort, e.reusePort) + port, err := e.stack.ReservePort(netProtos, ProtocolNumber, id.LocalAddress, id.LocalPort, e.reusePort, e.bindToDevice) if err != nil { return id, err } id.LocalPort = port } - err := e.stack.RegisterTransportEndpoint(nicid, netProtos, ProtocolNumber, id, e, e.reusePort) + err := e.stack.RegisterTransportEndpoint(nicid, netProtos, ProtocolNumber, id, e, e.reusePort, e.bindToDevice) if err != nil { - e.stack.ReleasePort(netProtos, ProtocolNumber, id.LocalAddress, id.LocalPort) + e.stack.ReleasePort(netProtos, ProtocolNumber, id.LocalAddress, id.LocalPort, e.bindToDevice) } return id, err } diff --git a/pkg/tcpip/transport/udp/forwarder.go b/pkg/tcpip/transport/udp/forwarder.go index a9edc2c8d..2d0bc5221 100644 --- a/pkg/tcpip/transport/udp/forwarder.go +++ b/pkg/tcpip/transport/udp/forwarder.go @@ -74,7 +74,7 @@ func (r *ForwarderRequest) ID() stack.TransportEndpointID { // CreateEndpoint creates a connected UDP endpoint for the session request. func (r *ForwarderRequest) CreateEndpoint(queue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { ep := newEndpoint(r.stack, r.route.NetProto, queue) - if err := r.stack.RegisterTransportEndpoint(r.route.NICID(), []tcpip.NetworkProtocolNumber{r.route.NetProto}, ProtocolNumber, r.id, ep, ep.reusePort); err != nil { + if err := r.stack.RegisterTransportEndpoint(r.route.NICID(), []tcpip.NetworkProtocolNumber{r.route.NetProto}, ProtocolNumber, r.id, ep, ep.reusePort, ep.bindToDevice); err != nil { ep.Close() return nil, err } diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go index 2ec27be4d..5059ca22d 100644 --- a/pkg/tcpip/transport/udp/udp_test.go +++ b/pkg/tcpip/transport/udp/udp_test.go @@ -17,7 +17,6 @@ package udp_test import ( "bytes" "fmt" - "math" "math/rand" "testing" "time" @@ -27,6 +26,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/checker" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/link/channel" + "gvisor.dev/gvisor/pkg/tcpip/link/loopback" "gvisor.dev/gvisor/pkg/tcpip/link/sniffer" "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" @@ -476,87 +476,59 @@ func newMinPayload(minSize int) []byte { return b } -func TestBindPortReuse(t *testing.T) { - c := newDualTestContext(t, defaultMTU) - defer c.cleanup() - - c.createEndpoint(ipv6.ProtocolNumber) - - var eps [5]tcpip.Endpoint - reusePortOpt := tcpip.ReusePortOption(1) - - pollChannel := make(chan tcpip.Endpoint) - for i := 0; i < len(eps); i++ { - // Try to receive the data. - wq := waiter.Queue{} - we, ch := waiter.NewChannelEntry(nil) - wq.EventRegister(&we, waiter.EventIn) - defer wq.EventUnregister(&we) - defer close(ch) - - var err *tcpip.Error - eps[i], err = c.s.NewEndpoint(udp.ProtocolNumber, ipv6.ProtocolNumber, &wq) - if err != nil { - c.t.Fatalf("NewEndpoint failed: %v", err) - } - - go func(ep tcpip.Endpoint) { - for range ch { - pollChannel <- ep - } - }(eps[i]) +func TestBindToDeviceOption(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()}, + TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()}}) - defer eps[i].Close() - if err := eps[i].SetSockOpt(reusePortOpt); err != nil { - c.t.Fatalf("SetSockOpt failed failed: %v", err) - } - if err := eps[i].Bind(tcpip.FullAddress{Addr: stackV6Addr, Port: stackPort}); err != nil { - t.Fatalf("ep.Bind(...) failed: %v", err) - } + ep, err := s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{}) + if err != nil { + t.Fatalf("NewEndpoint failed; %v", err) } + defer ep.Close() - npackets := 100000 - nports := 10000 - ports := make(map[uint16]tcpip.Endpoint) - stats := make(map[tcpip.Endpoint]int) - for i := 0; i < npackets; i++ { - // Send a packet. - port := uint16(i % nports) - payload := newPayload() - c.injectV6Packet(payload, &header4Tuple{ - srcAddr: tcpip.FullAddress{Addr: testV6Addr, Port: testPort + port}, - dstAddr: tcpip.FullAddress{Addr: stackV6Addr, Port: stackPort}, - }) + if err := s.CreateNamedNIC(321, "my_device", loopback.New()); err != nil { + t.Errorf("CreateNamedNIC failed: %v", err) + } - var addr tcpip.FullAddress - ep := <-pollChannel - _, _, err := ep.Read(&addr) - if err != nil { - c.t.Fatalf("Read failed: %v", err) - } - stats[ep]++ - if i < nports { - ports[uint16(i)] = ep - } else { - // Check that all packets from one client are handled - // by the same socket. - if ports[port] != ep { - t.Fatalf("Port mismatch") - } - } + // Make an nameless NIC. + if err := s.CreateNIC(54321, loopback.New()); err != nil { + t.Errorf("CreateNIC failed: %v", err) } - if len(stats) != len(eps) { - t.Fatalf("Only %d(expected %d) sockets received packets", len(stats), len(eps)) + // strPtr is used instead of taking the address of string literals, which is + // a compiler error. + strPtr := func(s string) *string { + return &s } - // Check that a packet distribution is fair between sockets. - for _, c := range stats { - n := float64(npackets) / float64(len(eps)) - // The deviation is less than 10%. - if math.Abs(float64(c)-n) > n/10 { - t.Fatal(c, n) - } + testActions := []struct { + name string + setBindToDevice *string + setBindToDeviceError *tcpip.Error + getBindToDevice tcpip.BindToDeviceOption + }{ + {"GetDefaultValue", nil, nil, ""}, + {"BindToNonExistent", strPtr("non_existent_device"), tcpip.ErrUnknownDevice, ""}, + {"BindToExistent", strPtr("my_device"), nil, "my_device"}, + {"UnbindToDevice", strPtr(""), nil, ""}, + } + for _, testAction := range testActions { + t.Run(testAction.name, func(t *testing.T) { + if testAction.setBindToDevice != nil { + bindToDevice := tcpip.BindToDeviceOption(*testAction.setBindToDevice) + if got, want := ep.SetSockOpt(bindToDevice), testAction.setBindToDeviceError; got != want { + t.Errorf("SetSockOpt(%v) got %v, want %v", bindToDevice, got, want) + } + } + bindToDevice := tcpip.BindToDeviceOption("to be modified by GetSockOpt") + if ep.GetSockOpt(&bindToDevice) != nil { + t.Errorf("GetSockOpt got %v, want %v", ep.GetSockOpt(&bindToDevice), nil) + } + if got, want := bindToDevice, testAction.getBindToDevice; got != want { + t.Errorf("bindToDevice got %q, want %q", got, want) + } + }) } } diff --git a/test/syscalls/linux/BUILD b/test/syscalls/linux/BUILD index 28b23ce58..e645eebfa 100644 --- a/test/syscalls/linux/BUILD +++ b/test/syscalls/linux/BUILD @@ -2464,6 +2464,63 @@ cc_binary( ) cc_binary( + name = "socket_bind_to_device_test", + testonly = 1, + srcs = [ + "socket_bind_to_device.cc", + ], + linkstatic = 1, + deps = [ + ":ip_socket_test_util", + ":socket_bind_to_device_util", + ":socket_test_util", + "//test/util:capability_util", + "//test/util:test_main", + "//test/util:test_util", + "//test/util:thread_util", + "@com_google_googletest//:gtest", + ], +) + +cc_binary( + name = "socket_bind_to_device_sequence_test", + testonly = 1, + srcs = [ + "socket_bind_to_device_sequence.cc", + ], + linkstatic = 1, + deps = [ + ":ip_socket_test_util", + ":socket_bind_to_device_util", + ":socket_test_util", + "//test/util:capability_util", + "//test/util:test_main", + "//test/util:test_util", + "//test/util:thread_util", + "@com_google_googletest//:gtest", + ], +) + +cc_binary( + name = "socket_bind_to_device_distribution_test", + testonly = 1, + srcs = [ + "socket_bind_to_device_distribution.cc", + ], + linkstatic = 1, + deps = [ + ":ip_socket_test_util", + ":socket_bind_to_device_util", + ":socket_test_util", + "//test/util:capability_util", + "//test/util:test_main", + "//test/util:test_util", + "//test/util:thread_util", + "@com_google_googletest//:gtest", + ], +) + +cc_binary( name = "socket_ip_udp_loopback_non_blocking_test", testonly = 1, srcs = [ @@ -2740,6 +2797,23 @@ cc_library( alwayslink = 1, ) +cc_library( + name = "socket_bind_to_device_util", + testonly = 1, + srcs = [ + "socket_bind_to_device_util.cc", + ], + hdrs = [ + "socket_bind_to_device_util.h", + ], + deps = [ + "//test/util:test_util", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", + ], + alwayslink = 1, +) + cc_binary( name = "socket_stream_local_test", testonly = 1, @@ -3253,6 +3327,7 @@ cc_binary( "//test/util:test_main", "//test/util:test_util", "//test/util:thread_util", + "//test/util:uid_util", "@com_google_absl//absl/flags:flag", "@com_google_absl//absl/strings", "@com_google_googletest//:gtest", diff --git a/test/syscalls/linux/socket_bind_to_device.cc b/test/syscalls/linux/socket_bind_to_device.cc new file mode 100644 index 000000000..d20821cac --- /dev/null +++ b/test/syscalls/linux/socket_bind_to_device.cc @@ -0,0 +1,314 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <arpa/inet.h> +#include <linux/if_tun.h> +#include <net/if.h> +#include <netinet/in.h> +#include <sys/ioctl.h> +#include <sys/socket.h> +#include <sys/types.h> +#include <sys/un.h> + +#include <cstdio> +#include <cstring> +#include <map> +#include <memory> +#include <unordered_map> +#include <unordered_set> +#include <utility> +#include <vector> + +#include "gmock/gmock.h" +#include "gtest/gtest.h" +#include "gtest/gtest.h" +#include "test/syscalls/linux/ip_socket_test_util.h" +#include "test/syscalls/linux/socket_bind_to_device_util.h" +#include "test/syscalls/linux/socket_test_util.h" +#include "test/util/capability_util.h" +#include "test/util/test_util.h" +#include "test/util/thread_util.h" + +namespace gvisor { +namespace testing { + +using std::string; + +// Test fixture for SO_BINDTODEVICE tests. +class BindToDeviceTest : public ::testing::TestWithParam<SocketKind> { + protected: + void SetUp() override { + printf("Testing case: %s\n", GetParam().description.c_str()); + ASSERT_TRUE(ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))) + << "CAP_NET_RAW is required to use SO_BINDTODEVICE"; + + interface_name_ = "eth1"; + auto interface_names = GetInterfaceNames(); + if (interface_names.find(interface_name_) == interface_names.end()) { + // Need a tunnel. + tunnel_ = ASSERT_NO_ERRNO_AND_VALUE(Tunnel::New()); + interface_name_ = tunnel_->GetName(); + ASSERT_FALSE(interface_name_.empty()); + } + socket_ = ASSERT_NO_ERRNO_AND_VALUE(GetParam().Create()); + } + + string interface_name() const { return interface_name_; } + + int socket_fd() const { return socket_->get(); } + + private: + std::unique_ptr<Tunnel> tunnel_; + string interface_name_; + std::unique_ptr<FileDescriptor> socket_; +}; + +constexpr char kIllegalIfnameChar = '/'; + +// Tests getsockopt of the default value. +TEST_P(BindToDeviceTest, GetsockoptDefault) { + char name_buffer[IFNAMSIZ * 2]; + char original_name_buffer[IFNAMSIZ * 2]; + socklen_t name_buffer_size; + + // Read the default SO_BINDTODEVICE. + memset(original_name_buffer, kIllegalIfnameChar, sizeof(name_buffer)); + for (size_t i = 0; i <= sizeof(name_buffer); i++) { + memset(name_buffer, kIllegalIfnameChar, sizeof(name_buffer)); + name_buffer_size = i; + EXPECT_THAT(getsockopt(socket_fd(), SOL_SOCKET, SO_BINDTODEVICE, + name_buffer, &name_buffer_size), + SyscallSucceedsWithValue(0)); + EXPECT_EQ(name_buffer_size, 0); + EXPECT_EQ(memcmp(name_buffer, original_name_buffer, sizeof(name_buffer)), + 0); + } +} + +// Tests setsockopt of invalid device name. +TEST_P(BindToDeviceTest, SetsockoptInvalidDeviceName) { + char name_buffer[IFNAMSIZ * 2]; + socklen_t name_buffer_size; + + // Set an invalid device name. + memset(name_buffer, kIllegalIfnameChar, 5); + name_buffer_size = 5; + EXPECT_THAT(setsockopt(socket_fd(), SOL_SOCKET, SO_BINDTODEVICE, name_buffer, + name_buffer_size), + SyscallFailsWithErrno(ENODEV)); +} + +// Tests setsockopt of a buffer with a valid device name but not +// null-terminated, with different sizes of buffer. +TEST_P(BindToDeviceTest, SetsockoptValidDeviceNameWithoutNullTermination) { + char name_buffer[IFNAMSIZ * 2]; + socklen_t name_buffer_size; + + strncpy(name_buffer, interface_name().c_str(), interface_name().size() + 1); + // Intentionally overwrite the null at the end. + memset(name_buffer + interface_name().size(), kIllegalIfnameChar, + sizeof(name_buffer) - interface_name().size()); + for (size_t i = 1; i <= sizeof(name_buffer); i++) { + name_buffer_size = i; + SCOPED_TRACE(absl::StrCat("Buffer size: ", i)); + // It should only work if the size provided is exactly right. + if (name_buffer_size == interface_name().size()) { + EXPECT_THAT(setsockopt(socket_fd(), SOL_SOCKET, SO_BINDTODEVICE, + name_buffer, name_buffer_size), + SyscallSucceeds()); + } else { + EXPECT_THAT(setsockopt(socket_fd(), SOL_SOCKET, SO_BINDTODEVICE, + name_buffer, name_buffer_size), + SyscallFailsWithErrno(ENODEV)); + } + } +} + +// Tests setsockopt of a buffer with a valid device name and null-terminated, +// with different sizes of buffer. +TEST_P(BindToDeviceTest, SetsockoptValidDeviceNameWithNullTermination) { + char name_buffer[IFNAMSIZ * 2]; + socklen_t name_buffer_size; + + strncpy(name_buffer, interface_name().c_str(), interface_name().size() + 1); + // Don't overwrite the null at the end. + memset(name_buffer + interface_name().size() + 1, kIllegalIfnameChar, + sizeof(name_buffer) - interface_name().size() - 1); + for (size_t i = 1; i <= sizeof(name_buffer); i++) { + name_buffer_size = i; + SCOPED_TRACE(absl::StrCat("Buffer size: ", i)); + // It should only work if the size provided is at least the right size. + if (name_buffer_size >= interface_name().size()) { + EXPECT_THAT(setsockopt(socket_fd(), SOL_SOCKET, SO_BINDTODEVICE, + name_buffer, name_buffer_size), + SyscallSucceeds()); + } else { + EXPECT_THAT(setsockopt(socket_fd(), SOL_SOCKET, SO_BINDTODEVICE, + name_buffer, name_buffer_size), + SyscallFailsWithErrno(ENODEV)); + } + } +} + +// Tests that setsockopt of an invalid device name doesn't unset the previous +// valid setsockopt. +TEST_P(BindToDeviceTest, SetsockoptValidThenInvalid) { + char name_buffer[IFNAMSIZ * 2]; + socklen_t name_buffer_size; + + // Write successfully. + strncpy(name_buffer, interface_name().c_str(), sizeof(name_buffer)); + ASSERT_THAT(setsockopt(socket_fd(), SOL_SOCKET, SO_BINDTODEVICE, name_buffer, + sizeof(name_buffer)), + SyscallSucceeds()); + + // Read it back successfully. + memset(name_buffer, kIllegalIfnameChar, sizeof(name_buffer)); + name_buffer_size = sizeof(name_buffer); + EXPECT_THAT(getsockopt(socket_fd(), SOL_SOCKET, SO_BINDTODEVICE, name_buffer, + &name_buffer_size), + SyscallSucceeds()); + EXPECT_EQ(name_buffer_size, interface_name().size() + 1); + EXPECT_STREQ(name_buffer, interface_name().c_str()); + + // Write unsuccessfully. + memset(name_buffer, kIllegalIfnameChar, sizeof(name_buffer)); + name_buffer_size = 5; + EXPECT_THAT(setsockopt(socket_fd(), SOL_SOCKET, SO_BINDTODEVICE, name_buffer, + sizeof(name_buffer)), + SyscallFailsWithErrno(ENODEV)); + + // Read it back successfully, it's unchanged. + memset(name_buffer, kIllegalIfnameChar, sizeof(name_buffer)); + name_buffer_size = sizeof(name_buffer); + EXPECT_THAT(getsockopt(socket_fd(), SOL_SOCKET, SO_BINDTODEVICE, name_buffer, + &name_buffer_size), + SyscallSucceeds()); + EXPECT_EQ(name_buffer_size, interface_name().size() + 1); + EXPECT_STREQ(name_buffer, interface_name().c_str()); +} + +// Tests that setsockopt of zero-length string correctly unsets the previous +// value. +TEST_P(BindToDeviceTest, SetsockoptValidThenClear) { + char name_buffer[IFNAMSIZ * 2]; + socklen_t name_buffer_size; + + // Write successfully. + strncpy(name_buffer, interface_name().c_str(), sizeof(name_buffer)); + EXPECT_THAT(setsockopt(socket_fd(), SOL_SOCKET, SO_BINDTODEVICE, name_buffer, + sizeof(name_buffer)), + SyscallSucceeds()); + + // Read it back successfully. + memset(name_buffer, kIllegalIfnameChar, sizeof(name_buffer)); + name_buffer_size = sizeof(name_buffer); + EXPECT_THAT(getsockopt(socket_fd(), SOL_SOCKET, SO_BINDTODEVICE, name_buffer, + &name_buffer_size), + SyscallSucceeds()); + EXPECT_EQ(name_buffer_size, interface_name().size() + 1); + EXPECT_STREQ(name_buffer, interface_name().c_str()); + + // Clear it successfully. + name_buffer_size = 0; + EXPECT_THAT(setsockopt(socket_fd(), SOL_SOCKET, SO_BINDTODEVICE, name_buffer, + name_buffer_size), + SyscallSucceeds()); + + // Read it back successfully, it's cleared. + memset(name_buffer, kIllegalIfnameChar, sizeof(name_buffer)); + name_buffer_size = sizeof(name_buffer); + EXPECT_THAT(getsockopt(socket_fd(), SOL_SOCKET, SO_BINDTODEVICE, name_buffer, + &name_buffer_size), + SyscallSucceeds()); + EXPECT_EQ(name_buffer_size, 0); +} + +// Tests that setsockopt of empty string correctly unsets the previous +// value. +TEST_P(BindToDeviceTest, SetsockoptValidThenClearWithNull) { + char name_buffer[IFNAMSIZ * 2]; + socklen_t name_buffer_size; + + // Write successfully. + strncpy(name_buffer, interface_name().c_str(), sizeof(name_buffer)); + EXPECT_THAT(setsockopt(socket_fd(), SOL_SOCKET, SO_BINDTODEVICE, name_buffer, + sizeof(name_buffer)), + SyscallSucceeds()); + + // Read it back successfully. + memset(name_buffer, kIllegalIfnameChar, sizeof(name_buffer)); + name_buffer_size = sizeof(name_buffer); + EXPECT_THAT(getsockopt(socket_fd(), SOL_SOCKET, SO_BINDTODEVICE, name_buffer, + &name_buffer_size), + SyscallSucceeds()); + EXPECT_EQ(name_buffer_size, interface_name().size() + 1); + EXPECT_STREQ(name_buffer, interface_name().c_str()); + + // Clear it successfully. + memset(name_buffer, kIllegalIfnameChar, sizeof(name_buffer)); + name_buffer[0] = 0; + name_buffer_size = sizeof(name_buffer); + EXPECT_THAT(setsockopt(socket_fd(), SOL_SOCKET, SO_BINDTODEVICE, name_buffer, + name_buffer_size), + SyscallSucceeds()); + + // Read it back successfully, it's cleared. + memset(name_buffer, kIllegalIfnameChar, sizeof(name_buffer)); + name_buffer_size = sizeof(name_buffer); + EXPECT_THAT(getsockopt(socket_fd(), SOL_SOCKET, SO_BINDTODEVICE, name_buffer, + &name_buffer_size), + SyscallSucceeds()); + EXPECT_EQ(name_buffer_size, 0); +} + +// Tests getsockopt with different buffer sizes. +TEST_P(BindToDeviceTest, GetsockoptDevice) { + char name_buffer[IFNAMSIZ * 2]; + socklen_t name_buffer_size; + + // Write successfully. + strncpy(name_buffer, interface_name().c_str(), sizeof(name_buffer)); + ASSERT_THAT(setsockopt(socket_fd(), SOL_SOCKET, SO_BINDTODEVICE, name_buffer, + sizeof(name_buffer)), + SyscallSucceeds()); + + // Read it back at various buffer sizes. + for (size_t i = 0; i <= sizeof(name_buffer); i++) { + memset(name_buffer, kIllegalIfnameChar, sizeof(name_buffer)); + name_buffer_size = i; + SCOPED_TRACE(absl::StrCat("Buffer size: ", i)); + // Linux only allows a buffer at least IFNAMSIZ, even if less would suffice + // for this interface name. + if (name_buffer_size >= IFNAMSIZ) { + EXPECT_THAT(getsockopt(socket_fd(), SOL_SOCKET, SO_BINDTODEVICE, + name_buffer, &name_buffer_size), + SyscallSucceeds()); + EXPECT_EQ(name_buffer_size, interface_name().size() + 1); + EXPECT_STREQ(name_buffer, interface_name().c_str()); + } else { + EXPECT_THAT(getsockopt(socket_fd(), SOL_SOCKET, SO_BINDTODEVICE, + name_buffer, &name_buffer_size), + SyscallFailsWithErrno(EINVAL)); + EXPECT_EQ(name_buffer_size, i); + } + } +} + +INSTANTIATE_TEST_SUITE_P(BindToDeviceTest, BindToDeviceTest, + ::testing::Values(IPv4UDPUnboundSocket(0), + IPv4TCPUnboundSocket(0))); + +} // namespace testing +} // namespace gvisor diff --git a/test/syscalls/linux/socket_bind_to_device_distribution.cc b/test/syscalls/linux/socket_bind_to_device_distribution.cc new file mode 100644 index 000000000..4d2400328 --- /dev/null +++ b/test/syscalls/linux/socket_bind_to_device_distribution.cc @@ -0,0 +1,381 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <arpa/inet.h> +#include <linux/if_tun.h> +#include <net/if.h> +#include <netinet/in.h> +#include <sys/ioctl.h> +#include <sys/socket.h> +#include <sys/types.h> +#include <sys/un.h> + +#include <atomic> +#include <cstdio> +#include <cstring> +#include <map> +#include <memory> +#include <unordered_map> +#include <unordered_set> +#include <utility> +#include <vector> + +#include "gmock/gmock.h" +#include "gtest/gtest.h" +#include "gtest/gtest.h" +#include "test/syscalls/linux/ip_socket_test_util.h" +#include "test/syscalls/linux/socket_bind_to_device_util.h" +#include "test/syscalls/linux/socket_test_util.h" +#include "test/util/capability_util.h" +#include "test/util/test_util.h" +#include "test/util/thread_util.h" + +namespace gvisor { +namespace testing { + +using std::string; +using std::vector; + +struct EndpointConfig { + std::string bind_to_device; + double expected_ratio; +}; + +struct DistributionTestCase { + std::string name; + std::vector<EndpointConfig> endpoints; +}; + +struct ListenerConnector { + TestAddress listener; + TestAddress connector; +}; + +// Test fixture for SO_BINDTODEVICE tests the distribution of packets received +// with varying SO_BINDTODEVICE settings. +class BindToDeviceDistributionTest + : public ::testing::TestWithParam< + ::testing::tuple<ListenerConnector, DistributionTestCase>> { + protected: + void SetUp() override { + printf("Testing case: %s, listener=%s, connector=%s\n", + ::testing::get<1>(GetParam()).name.c_str(), + ::testing::get<0>(GetParam()).listener.description.c_str(), + ::testing::get<0>(GetParam()).connector.description.c_str()); + ASSERT_TRUE(ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))) + << "CAP_NET_RAW is required to use SO_BINDTODEVICE"; + } +}; + +PosixErrorOr<uint16_t> AddrPort(int family, sockaddr_storage const& addr) { + switch (family) { + case AF_INET: + return static_cast<uint16_t>( + reinterpret_cast<sockaddr_in const*>(&addr)->sin_port); + case AF_INET6: + return static_cast<uint16_t>( + reinterpret_cast<sockaddr_in6 const*>(&addr)->sin6_port); + default: + return PosixError(EINVAL, + absl::StrCat("unknown socket family: ", family)); + } +} + +PosixError SetAddrPort(int family, sockaddr_storage* addr, uint16_t port) { + switch (family) { + case AF_INET: + reinterpret_cast<sockaddr_in*>(addr)->sin_port = port; + return NoError(); + case AF_INET6: + reinterpret_cast<sockaddr_in6*>(addr)->sin6_port = port; + return NoError(); + default: + return PosixError(EINVAL, + absl::StrCat("unknown socket family: ", family)); + } +} + +// Binds sockets to different devices and then creates many TCP connections. +// Checks that the distribution of connections received on the sockets matches +// the expectation. +TEST_P(BindToDeviceDistributionTest, Tcp) { + auto const& [listener_connector, test] = GetParam(); + + TestAddress const& listener = listener_connector.listener; + TestAddress const& connector = listener_connector.connector; + sockaddr_storage listen_addr = listener.addr; + sockaddr_storage conn_addr = connector.addr; + + auto interface_names = GetInterfaceNames(); + + // Create the listening sockets. + std::vector<FileDescriptor> listener_fds; + std::vector<std::unique_ptr<Tunnel>> all_tunnels; + for (auto const& endpoint : test.endpoints) { + if (!endpoint.bind_to_device.empty() && + interface_names.find(endpoint.bind_to_device) == + interface_names.end()) { + all_tunnels.push_back( + ASSERT_NO_ERRNO_AND_VALUE(Tunnel::New(endpoint.bind_to_device))); + interface_names.insert(endpoint.bind_to_device); + } + + listener_fds.push_back(ASSERT_NO_ERRNO_AND_VALUE( + Socket(listener.family(), SOCK_STREAM, IPPROTO_TCP))); + int fd = listener_fds.back().get(); + + ASSERT_THAT(setsockopt(fd, SOL_SOCKET, SO_REUSEPORT, &kSockOptOn, + sizeof(kSockOptOn)), + SyscallSucceeds()); + ASSERT_THAT(setsockopt(fd, SOL_SOCKET, SO_BINDTODEVICE, + endpoint.bind_to_device.c_str(), + endpoint.bind_to_device.size() + 1), + SyscallSucceeds()); + ASSERT_THAT( + bind(fd, reinterpret_cast<sockaddr*>(&listen_addr), listener.addr_len), + SyscallSucceeds()); + ASSERT_THAT(listen(fd, 40), SyscallSucceeds()); + + // On the first bind we need to determine which port was bound. + if (listener_fds.size() > 1) { + continue; + } + + // Get the port bound by the listening socket. + socklen_t addrlen = listener.addr_len; + ASSERT_THAT( + getsockname(listener_fds[0].get(), + reinterpret_cast<sockaddr*>(&listen_addr), &addrlen), + SyscallSucceeds()); + uint16_t const port = + ASSERT_NO_ERRNO_AND_VALUE(AddrPort(listener.family(), listen_addr)); + ASSERT_NO_ERRNO(SetAddrPort(connector.family(), &conn_addr, port)); + } + + constexpr int kConnectAttempts = 10000; + std::atomic<int> connects_received = ATOMIC_VAR_INIT(0); + std::vector<int> accept_counts(listener_fds.size(), 0); + std::vector<std::unique_ptr<ScopedThread>> listen_threads( + listener_fds.size()); + + for (int i = 0; i < listener_fds.size(); i++) { + listen_threads[i] = absl::make_unique<ScopedThread>( + [&listener_fds, &accept_counts, &connects_received, i, + kConnectAttempts]() { + do { + auto fd = Accept(listener_fds[i].get(), nullptr, nullptr); + if (!fd.ok()) { + // Another thread has shutdown our read side causing the accept to + // fail. + ASSERT_GE(connects_received, kConnectAttempts) + << "errno = " << fd.error(); + return; + } + // Receive some data from a socket to be sure that the connect() + // system call has been completed on another side. + int data; + EXPECT_THAT( + RetryEINTR(recv)(fd.ValueOrDie().get(), &data, sizeof(data), 0), + SyscallSucceedsWithValue(sizeof(data))); + accept_counts[i]++; + } while (++connects_received < kConnectAttempts); + + // Shutdown all sockets to wake up other threads. + for (auto const& listener_fd : listener_fds) { + shutdown(listener_fd.get(), SHUT_RDWR); + } + }); + } + + for (int i = 0; i < kConnectAttempts; i++) { + FileDescriptor const fd = ASSERT_NO_ERRNO_AND_VALUE( + Socket(connector.family(), SOCK_STREAM, IPPROTO_TCP)); + ASSERT_THAT( + RetryEINTR(connect)(fd.get(), reinterpret_cast<sockaddr*>(&conn_addr), + connector.addr_len), + SyscallSucceeds()); + + EXPECT_THAT(RetryEINTR(send)(fd.get(), &i, sizeof(i), 0), + SyscallSucceedsWithValue(sizeof(i))); + } + + // Join threads to be sure that all connections have been counted. + for (auto const& listen_thread : listen_threads) { + listen_thread->Join(); + } + // Check that connections are distributed correctly among listening sockets. + for (int i = 0; i < accept_counts.size(); i++) { + EXPECT_THAT( + accept_counts[i], + EquivalentWithin(static_cast<int>(kConnectAttempts * + test.endpoints[i].expected_ratio), + 0.10)) + << "endpoint " << i << " got the wrong number of packets"; + } +} + +// Binds sockets to different devices and then sends many UDP packets. Checks +// that the distribution of packets received on the sockets matches the +// expectation. +TEST_P(BindToDeviceDistributionTest, Udp) { + auto const& [listener_connector, test] = GetParam(); + + TestAddress const& listener = listener_connector.listener; + TestAddress const& connector = listener_connector.connector; + sockaddr_storage listen_addr = listener.addr; + sockaddr_storage conn_addr = connector.addr; + + auto interface_names = GetInterfaceNames(); + + // Create the listening socket. + std::vector<FileDescriptor> listener_fds; + std::vector<std::unique_ptr<Tunnel>> all_tunnels; + for (auto const& endpoint : test.endpoints) { + if (!endpoint.bind_to_device.empty() && + interface_names.find(endpoint.bind_to_device) == + interface_names.end()) { + all_tunnels.push_back( + ASSERT_NO_ERRNO_AND_VALUE(Tunnel::New(endpoint.bind_to_device))); + interface_names.insert(endpoint.bind_to_device); + } + + listener_fds.push_back( + ASSERT_NO_ERRNO_AND_VALUE(Socket(listener.family(), SOCK_DGRAM, 0))); + int fd = listener_fds.back().get(); + + ASSERT_THAT(setsockopt(fd, SOL_SOCKET, SO_REUSEPORT, &kSockOptOn, + sizeof(kSockOptOn)), + SyscallSucceeds()); + ASSERT_THAT(setsockopt(fd, SOL_SOCKET, SO_BINDTODEVICE, + endpoint.bind_to_device.c_str(), + endpoint.bind_to_device.size() + 1), + SyscallSucceeds()); + ASSERT_THAT( + bind(fd, reinterpret_cast<sockaddr*>(&listen_addr), listener.addr_len), + SyscallSucceeds()); + + // On the first bind we need to determine which port was bound. + if (listener_fds.size() > 1) { + continue; + } + + // Get the port bound by the listening socket. + socklen_t addrlen = listener.addr_len; + ASSERT_THAT( + getsockname(listener_fds[0].get(), + reinterpret_cast<sockaddr*>(&listen_addr), &addrlen), + SyscallSucceeds()); + uint16_t const port = + ASSERT_NO_ERRNO_AND_VALUE(AddrPort(listener.family(), listen_addr)); + ASSERT_NO_ERRNO(SetAddrPort(listener.family(), &listen_addr, port)); + ASSERT_NO_ERRNO(SetAddrPort(connector.family(), &conn_addr, port)); + } + + constexpr int kConnectAttempts = 10000; + std::atomic<int> packets_received = ATOMIC_VAR_INIT(0); + std::vector<int> packets_per_socket(listener_fds.size(), 0); + std::vector<std::unique_ptr<ScopedThread>> receiver_threads( + listener_fds.size()); + + for (int i = 0; i < listener_fds.size(); i++) { + receiver_threads[i] = absl::make_unique<ScopedThread>( + [&listener_fds, &packets_per_socket, &packets_received, i]() { + do { + struct sockaddr_storage addr = {}; + socklen_t addrlen = sizeof(addr); + int data; + + auto ret = RetryEINTR(recvfrom)( + listener_fds[i].get(), &data, sizeof(data), 0, + reinterpret_cast<struct sockaddr*>(&addr), &addrlen); + + if (packets_received < kConnectAttempts) { + ASSERT_THAT(ret, SyscallSucceedsWithValue(sizeof(data))); + } + + if (ret != sizeof(data)) { + // Another thread may have shutdown our read side causing the + // recvfrom to fail. + break; + } + + packets_received++; + packets_per_socket[i]++; + + // A response is required to synchronize with the main thread, + // otherwise the main thread can send more than can fit into receive + // queues. + EXPECT_THAT(RetryEINTR(sendto)( + listener_fds[i].get(), &data, sizeof(data), 0, + reinterpret_cast<sockaddr*>(&addr), addrlen), + SyscallSucceedsWithValue(sizeof(data))); + } while (packets_received < kConnectAttempts); + + // Shutdown all sockets to wake up other threads. + for (auto const& listener_fd : listener_fds) { + shutdown(listener_fd.get(), SHUT_RDWR); + } + }); + } + + for (int i = 0; i < kConnectAttempts; i++) { + FileDescriptor const fd = + ASSERT_NO_ERRNO_AND_VALUE(Socket(connector.family(), SOCK_DGRAM, 0)); + EXPECT_THAT(RetryEINTR(sendto)(fd.get(), &i, sizeof(i), 0, + reinterpret_cast<sockaddr*>(&conn_addr), + connector.addr_len), + SyscallSucceedsWithValue(sizeof(i))); + int data; + EXPECT_THAT(RetryEINTR(recv)(fd.get(), &data, sizeof(data), 0), + SyscallSucceedsWithValue(sizeof(data))); + } + + // Join threads to be sure that all connections have been counted. + for (auto const& receiver_thread : receiver_threads) { + receiver_thread->Join(); + } + // Check that packets are distributed correctly among listening sockets. + for (int i = 0; i < packets_per_socket.size(); i++) { + EXPECT_THAT( + packets_per_socket[i], + EquivalentWithin(static_cast<int>(kConnectAttempts * + test.endpoints[i].expected_ratio), + 0.10)) + << "endpoint " << i << " got the wrong number of packets"; + } +} + +std::vector<DistributionTestCase> GetDistributionTestCases() { + return std::vector<DistributionTestCase>{ + {"Even distribution among sockets not bound to device", + {{"", 1. / 3}, {"", 1. / 3}, {"", 1. / 3}}}, + {"Sockets bound to other interfaces get no packets", + {{"eth1", 0}, {"", 1. / 2}, {"", 1. / 2}}}, + {"Bound has priority over unbound", {{"eth1", 0}, {"", 0}, {"lo", 1}}}, + {"Even distribution among sockets bound to device", + {{"eth1", 0}, {"lo", 1. / 2}, {"lo", 1. / 2}}}, + }; +} + +INSTANTIATE_TEST_SUITE_P( + BindToDeviceTest, BindToDeviceDistributionTest, + ::testing::Combine(::testing::Values( + // Listeners bound to IPv4 addresses refuse + // connections using IPv6 addresses. + ListenerConnector{V4Any(), V4Loopback()}, + ListenerConnector{V4Loopback(), V4MappedLoopback()}), + ::testing::ValuesIn(GetDistributionTestCases()))); + +} // namespace testing +} // namespace gvisor diff --git a/test/syscalls/linux/socket_bind_to_device_sequence.cc b/test/syscalls/linux/socket_bind_to_device_sequence.cc new file mode 100644 index 000000000..a7365d139 --- /dev/null +++ b/test/syscalls/linux/socket_bind_to_device_sequence.cc @@ -0,0 +1,316 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <arpa/inet.h> +#include <linux/capability.h> +#include <linux/if_tun.h> +#include <net/if.h> +#include <netinet/in.h> +#include <sys/ioctl.h> +#include <sys/socket.h> +#include <sys/types.h> +#include <sys/un.h> + +#include <cstdio> +#include <cstring> +#include <map> +#include <memory> +#include <unordered_map> +#include <unordered_set> +#include <utility> +#include <vector> + +#include "gmock/gmock.h" +#include "gtest/gtest.h" +#include "gtest/gtest.h" +#include "test/syscalls/linux/ip_socket_test_util.h" +#include "test/syscalls/linux/socket_bind_to_device_util.h" +#include "test/syscalls/linux/socket_test_util.h" +#include "test/util/capability_util.h" +#include "test/util/test_util.h" +#include "test/util/thread_util.h" + +namespace gvisor { +namespace testing { + +using std::string; +using std::vector; + +// Test fixture for SO_BINDTODEVICE tests the results of sequences of socket +// binding. +class BindToDeviceSequenceTest : public ::testing::TestWithParam<SocketKind> { + protected: + void SetUp() override { + printf("Testing case: %s\n", GetParam().description.c_str()); + ASSERT_TRUE(ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))) + << "CAP_NET_RAW is required to use SO_BINDTODEVICE"; + socket_factory_ = GetParam(); + + interface_names_ = GetInterfaceNames(); + } + + PosixErrorOr<std::unique_ptr<FileDescriptor>> NewSocket() const { + return socket_factory_.Create(); + } + + // Gets a device by device_id. If the device_id has been seen before, returns + // the previously returned device. If not, finds or creates a new device. + // Returns an empty string on failure. + void GetDevice(int device_id, string *device_name) { + auto device = devices_.find(device_id); + if (device != devices_.end()) { + *device_name = device->second; + return; + } + + // Need to pick a new device. Try ethernet first. + *device_name = absl::StrCat("eth", next_unused_eth_); + if (interface_names_.find(*device_name) != interface_names_.end()) { + devices_[device_id] = *device_name; + next_unused_eth_++; + return; + } + + // Need to make a new tunnel device. gVisor tests should have enough + // ethernet devices to never reach here. + ASSERT_FALSE(IsRunningOnGvisor()); + // Need a tunnel. + tunnels_.push_back(ASSERT_NO_ERRNO_AND_VALUE(Tunnel::New())); + devices_[device_id] = tunnels_.back()->GetName(); + *device_name = devices_[device_id]; + } + + // Release the socket + void ReleaseSocket(int socket_id) { + // Close the socket that was made in a previous action. The socket_id + // indicates which socket to close based on index into the list of actions. + sockets_to_close_.erase(socket_id); + } + + // Bind a socket with the reuse option and bind_to_device options. Checks + // that all steps succeed and that the bind command's error matches want. + // Sets the socket_id to uniquely identify the socket bound if it is not + // nullptr. + void BindSocket(bool reuse, int device_id = 0, int want = 0, + int *socket_id = nullptr) { + next_socket_id_++; + sockets_to_close_[next_socket_id_] = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); + auto socket_fd = sockets_to_close_[next_socket_id_]->get(); + if (socket_id != nullptr) { + *socket_id = next_socket_id_; + } + + // If reuse is indicated, do that. + if (reuse) { + EXPECT_THAT(setsockopt(socket_fd, SOL_SOCKET, SO_REUSEPORT, &kSockOptOn, + sizeof(kSockOptOn)), + SyscallSucceedsWithValue(0)); + } + + // If the device is non-zero, bind to that device. + if (device_id != 0) { + string device_name; + ASSERT_NO_FATAL_FAILURE(GetDevice(device_id, &device_name)); + EXPECT_THAT(setsockopt(socket_fd, SOL_SOCKET, SO_BINDTODEVICE, + device_name.c_str(), device_name.size() + 1), + SyscallSucceedsWithValue(0)); + char get_device[100]; + socklen_t get_device_size = 100; + EXPECT_THAT(getsockopt(socket_fd, SOL_SOCKET, SO_BINDTODEVICE, get_device, + &get_device_size), + SyscallSucceedsWithValue(0)); + } + + struct sockaddr_in addr = {}; + addr.sin_family = AF_INET; + addr.sin_addr.s_addr = htonl(INADDR_LOOPBACK); + addr.sin_port = port_; + if (want == 0) { + ASSERT_THAT( + bind(socket_fd, reinterpret_cast<const struct sockaddr *>(&addr), + sizeof(addr)), + SyscallSucceeds()); + } else { + ASSERT_THAT( + bind(socket_fd, reinterpret_cast<const struct sockaddr *>(&addr), + sizeof(addr)), + SyscallFailsWithErrno(want)); + } + + if (port_ == 0) { + // We don't yet know what port we'll be using so we need to fetch it and + // remember it for future commands. + socklen_t addr_size = sizeof(addr); + ASSERT_THAT( + getsockname(socket_fd, reinterpret_cast<struct sockaddr *>(&addr), + &addr_size), + SyscallSucceeds()); + port_ = addr.sin_port; + } + } + + private: + SocketKind socket_factory_; + // devices maps from the device id in the test case to the name of the device. + std::unordered_map<int, string> devices_; + // These are the tunnels that were created for the test and will be destroyed + // by the destructor. + vector<std::unique_ptr<Tunnel>> tunnels_; + // A list of all interface names before the test started. + std::unordered_set<string> interface_names_; + // The next ethernet device to use when requested a device. + int next_unused_eth_ = 1; + // The port for all tests. Originally 0 (any) and later set to the port that + // all further commands will use. + in_port_t port_ = 0; + // sockets_to_close_ is a map from action index to the socket that was + // created. + std::unordered_map<int, + std::unique_ptr<gvisor::testing::FileDescriptor>> + sockets_to_close_; + int next_socket_id_ = 0; +}; + +TEST_P(BindToDeviceSequenceTest, BindTwiceWithDeviceFails) { + ASSERT_NO_FATAL_FAILURE( + BindSocket(/* reuse */ false, /* bind_to_device */ 3)); + ASSERT_NO_FATAL_FAILURE( + BindSocket(/* reuse */ false, /* bind_to_device */ 3, EADDRINUSE)); +} + +TEST_P(BindToDeviceSequenceTest, BindToDevice) { + ASSERT_NO_FATAL_FAILURE( + BindSocket(/* reuse */ false, /* bind_to_device */ 1)); + ASSERT_NO_FATAL_FAILURE( + BindSocket(/* reuse */ false, /* bind_to_device */ 2)); +} + +TEST_P(BindToDeviceSequenceTest, BindToDeviceAndThenWithoutDevice) { + ASSERT_NO_FATAL_FAILURE( + BindSocket(/* reuse */ false, /* bind_to_device */ 123)); + ASSERT_NO_FATAL_FAILURE( + BindSocket(/* reuse */ false, /* bind_to_device */ 0, EADDRINUSE)); +} + +TEST_P(BindToDeviceSequenceTest, BindWithoutDevice) { + ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse */ false)); + ASSERT_NO_FATAL_FAILURE( + BindSocket(/* reuse */ false, /* bind_to_device */ 123, EADDRINUSE)); + ASSERT_NO_FATAL_FAILURE( + BindSocket(/* reuse */ true, /* bind_to_device */ 123, EADDRINUSE)); + ASSERT_NO_FATAL_FAILURE( + BindSocket(/* reuse */ false, /* bind_to_device */ 0, EADDRINUSE)); + ASSERT_NO_FATAL_FAILURE( + BindSocket(/* reuse */ true, /* bind_to_device */ 0, EADDRINUSE)); +} + +TEST_P(BindToDeviceSequenceTest, BindWithDevice) { + ASSERT_NO_FATAL_FAILURE( + BindSocket(/* reuse */ false, /* bind_to_device */ 123, 0)); + ASSERT_NO_FATAL_FAILURE( + BindSocket(/* reuse */ false, /* bind_to_device */ 123, EADDRINUSE)); + ASSERT_NO_FATAL_FAILURE( + BindSocket(/* reuse */ true, /* bind_to_device */ 123, EADDRINUSE)); + ASSERT_NO_FATAL_FAILURE( + BindSocket(/* reuse */ false, /* bind_to_device */ 0, EADDRINUSE)); + ASSERT_NO_FATAL_FAILURE( + BindSocket(/* reuse */ true, /* bind_to_device */ 0, EADDRINUSE)); + ASSERT_NO_FATAL_FAILURE( + BindSocket(/* reuse */ true, /* bind_to_device */ 456, 0)); + ASSERT_NO_FATAL_FAILURE( + BindSocket(/* reuse */ false, /* bind_to_device */ 789, 0)); + ASSERT_NO_FATAL_FAILURE( + BindSocket(/* reuse */ false, /* bind_to_device */ 0, EADDRINUSE)); + ASSERT_NO_FATAL_FAILURE( + BindSocket(/* reuse */ true, /* bind_to_device */ 0, EADDRINUSE)); +} + +TEST_P(BindToDeviceSequenceTest, BindWithReuse) { + ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse */ true)); + ASSERT_NO_FATAL_FAILURE( + BindSocket(/* reuse */ false, /* bind_to_device */ 123, EADDRINUSE)); + ASSERT_NO_FATAL_FAILURE( + BindSocket(/* reuse */ true, /* bind_to_device */ 123)); + ASSERT_NO_FATAL_FAILURE( + BindSocket(/* reuse */ false, /* bind_to_device */ 0, EADDRINUSE)); + ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse */ true, /* bind_to_device */ 0)); +} + +TEST_P(BindToDeviceSequenceTest, BindingWithReuseAndDevice) { + ASSERT_NO_FATAL_FAILURE( + BindSocket(/* reuse */ true, /* bind_to_device */ 123)); + ASSERT_NO_FATAL_FAILURE( + BindSocket(/* reuse */ false, /* bind_to_device */ 123, EADDRINUSE)); + ASSERT_NO_FATAL_FAILURE( + BindSocket(/* reuse */ true, /* bind_to_device */ 123)); + ASSERT_NO_FATAL_FAILURE( + BindSocket(/* reuse */ false, /* bind_to_device */ 0, EADDRINUSE)); + ASSERT_NO_FATAL_FAILURE( + BindSocket(/* reuse */ true, /* bind_to_device */ 456)); + ASSERT_NO_FATAL_FAILURE(BindSocket(/* reuse */ true)); + ASSERT_NO_FATAL_FAILURE( + BindSocket(/* reuse */ true, /* bind_to_device */ 789)); + ASSERT_NO_FATAL_FAILURE( + BindSocket(/* reuse */ false, /* bind_to_device */ 999, EADDRINUSE)); +} + +TEST_P(BindToDeviceSequenceTest, MixingReuseAndNotReuseByBindingToDevice) { + ASSERT_NO_FATAL_FAILURE( + BindSocket(/* reuse */ true, /* bind_to_device */ 123, 0)); + ASSERT_NO_FATAL_FAILURE( + BindSocket(/* reuse */ false, /* bind_to_device */ 456, 0)); + ASSERT_NO_FATAL_FAILURE( + BindSocket(/* reuse */ true, /* bind_to_device */ 789, 0)); + ASSERT_NO_FATAL_FAILURE( + BindSocket(/* reuse */ false, /* bind_to_device */ 999, 0)); +} + +TEST_P(BindToDeviceSequenceTest, CannotBindTo0AfterMixingReuseAndNotReuse) { + ASSERT_NO_FATAL_FAILURE( + BindSocket(/* reuse */ true, /* bind_to_device */ 123)); + ASSERT_NO_FATAL_FAILURE( + BindSocket(/* reuse */ false, /* bind_to_device */ 456)); + ASSERT_NO_FATAL_FAILURE( + BindSocket(/* reuse */ true, /* bind_to_device */ 0, EADDRINUSE)); +} + +TEST_P(BindToDeviceSequenceTest, BindAndRelease) { + ASSERT_NO_FATAL_FAILURE( + BindSocket(/* reuse */ true, /* bind_to_device */ 123)); + int to_release; + ASSERT_NO_FATAL_FAILURE( + BindSocket(/* reuse */ true, /* bind_to_device */ 0, 0, &to_release)); + ASSERT_NO_FATAL_FAILURE( + BindSocket(/* reuse */ false, /* bind_to_device */ 345, EADDRINUSE)); + ASSERT_NO_FATAL_FAILURE( + BindSocket(/* reuse */ true, /* bind_to_device */ 789)); + // Release the bind to device 0 and try again. + ASSERT_NO_FATAL_FAILURE(ReleaseSocket(to_release)); + ASSERT_NO_FATAL_FAILURE( + BindSocket(/* reuse */ false, /* bind_to_device */ 345)); +} + +TEST_P(BindToDeviceSequenceTest, BindTwiceWithReuseOnce) { + ASSERT_NO_FATAL_FAILURE( + BindSocket(/* reuse */ false, /* bind_to_device */ 123)); + ASSERT_NO_FATAL_FAILURE( + BindSocket(/* reuse */ true, /* bind_to_device */ 0, EADDRINUSE)); +} + +INSTANTIATE_TEST_SUITE_P(BindToDeviceTest, BindToDeviceSequenceTest, + ::testing::Values(IPv4UDPUnboundSocket(0), + IPv4TCPUnboundSocket(0))); + +} // namespace testing +} // namespace gvisor diff --git a/test/syscalls/linux/socket_bind_to_device_util.cc b/test/syscalls/linux/socket_bind_to_device_util.cc new file mode 100644 index 000000000..f4ee775bd --- /dev/null +++ b/test/syscalls/linux/socket_bind_to_device_util.cc @@ -0,0 +1,75 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "test/syscalls/linux/socket_bind_to_device_util.h" + +#include <arpa/inet.h> +#include <fcntl.h> +#include <linux/if_tun.h> +#include <net/if.h> +#include <netinet/in.h> +#include <sys/ioctl.h> +#include <sys/socket.h> +#include <sys/types.h> +#include <sys/un.h> +#include <unistd.h> + +#include <cstdio> +#include <cstring> +#include <map> +#include <memory> +#include <unordered_map> +#include <unordered_set> +#include <utility> +#include <vector> + +#include "test/util/test_util.h" + +namespace gvisor { +namespace testing { + +using std::string; + +PosixErrorOr<std::unique_ptr<Tunnel>> Tunnel::New(string tunnel_name) { + int fd; + RETURN_ERROR_IF_SYSCALL_FAIL(fd = open("/dev/net/tun", O_RDWR)); + + // Using `new` to access a non-public constructor. + auto new_tunnel = absl::WrapUnique(new Tunnel(fd)); + + ifreq ifr = {}; + ifr.ifr_flags = IFF_TUN; + strncpy(ifr.ifr_name, tunnel_name.c_str(), sizeof(ifr.ifr_name)); + + RETURN_ERROR_IF_SYSCALL_FAIL(ioctl(fd, TUNSETIFF, &ifr)); + new_tunnel->name_ = ifr.ifr_name; + return new_tunnel; +} + +std::unordered_set<string> GetInterfaceNames() { + struct if_nameindex* interfaces = if_nameindex(); + std::unordered_set<string> names; + if (interfaces == nullptr) { + return names; + } + for (auto interface = interfaces; + interface->if_index != 0 || interface->if_name != nullptr; interface++) { + names.insert(interface->if_name); + } + if_freenameindex(interfaces); + return names; +} + +} // namespace testing +} // namespace gvisor diff --git a/test/syscalls/linux/socket_bind_to_device_util.h b/test/syscalls/linux/socket_bind_to_device_util.h new file mode 100644 index 000000000..f941ccc86 --- /dev/null +++ b/test/syscalls/linux/socket_bind_to_device_util.h @@ -0,0 +1,67 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef GVISOR_TEST_SYSCALLS_SOCKET_BIND_TO_DEVICE_UTILS_H_ +#define GVISOR_TEST_SYSCALLS_SOCKET_BIND_TO_DEVICE_UTILS_H_ + +#include <arpa/inet.h> +#include <linux/if_tun.h> +#include <net/if.h> +#include <netinet/in.h> +#include <sys/ioctl.h> +#include <sys/socket.h> +#include <sys/types.h> +#include <sys/un.h> +#include <unistd.h> + +#include <cstdio> +#include <cstring> +#include <map> +#include <memory> +#include <string> +#include <unordered_map> +#include <unordered_set> +#include <utility> +#include <vector> + +#include "absl/memory/memory.h" +#include "test/util/test_util.h" + +namespace gvisor { +namespace testing { + +class Tunnel { + public: + static PosixErrorOr<std::unique_ptr<Tunnel>> New( + std::string tunnel_name = ""); + const std::string& GetName() const { return name_; } + + ~Tunnel() { + if (fd_ != -1) { + close(fd_); + } + } + + private: + Tunnel(int fd) : fd_(fd) {} + int fd_ = -1; + std::string name_; +}; + +std::unordered_set<std::string> GetInterfaceNames(); + +} // namespace testing +} // namespace gvisor + +#endif // GVISOR_TEST_SYSCALLS_SOCKET_BIND_TO_DEVICE_UTILS_H_ diff --git a/test/syscalls/linux/uidgid.cc b/test/syscalls/linux/uidgid.cc index d48453a93..6218fbce1 100644 --- a/test/syscalls/linux/uidgid.cc +++ b/test/syscalls/linux/uidgid.cc @@ -25,6 +25,7 @@ #include "test/util/posix_error.h" #include "test/util/test_util.h" #include "test/util/thread_util.h" +#include "test/util/uid_util.h" ABSL_FLAG(int32_t, scratch_uid1, 65534, "first scratch UID"); ABSL_FLAG(int32_t, scratch_uid2, 65533, "second scratch UID"); @@ -68,30 +69,6 @@ TEST(UidGidTest, Getgroups) { // here; see the setgroups test below. } -// If the caller's real/effective/saved user/group IDs are all 0, IsRoot returns -// true. Otherwise IsRoot logs an explanatory message and returns false. -PosixErrorOr<bool> IsRoot() { - uid_t ruid, euid, suid; - int rc = getresuid(&ruid, &euid, &suid); - MaybeSave(); - if (rc < 0) { - return PosixError(errno, "getresuid"); - } - if (ruid != 0 || euid != 0 || suid != 0) { - return false; - } - gid_t rgid, egid, sgid; - rc = getresgid(&rgid, &egid, &sgid); - MaybeSave(); - if (rc < 0) { - return PosixError(errno, "getresgid"); - } - if (rgid != 0 || egid != 0 || sgid != 0) { - return false; - } - return true; -} - // Checks that the calling process' real/effective/saved user IDs are // ruid/euid/suid respectively. PosixError CheckUIDs(uid_t ruid, uid_t euid, uid_t suid) { diff --git a/test/util/BUILD b/test/util/BUILD index 25ed9c944..5d2a9cc2c 100644 --- a/test/util/BUILD +++ b/test/util/BUILD @@ -324,3 +324,14 @@ cc_library( ":test_util", ], ) + +cc_library( + name = "uid_util", + testonly = 1, + srcs = ["uid_util.cc"], + hdrs = ["uid_util.h"], + deps = [ + ":posix_error", + ":save_util", + ], +) diff --git a/test/util/uid_util.cc b/test/util/uid_util.cc new file mode 100644 index 000000000..b131b4b99 --- /dev/null +++ b/test/util/uid_util.cc @@ -0,0 +1,44 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "test/util/posix_error.h" +#include "test/util/save_util.h" + +namespace gvisor { +namespace testing { + +PosixErrorOr<bool> IsRoot() { + uid_t ruid, euid, suid; + int rc = getresuid(&ruid, &euid, &suid); + MaybeSave(); + if (rc < 0) { + return PosixError(errno, "getresuid"); + } + if (ruid != 0 || euid != 0 || suid != 0) { + return false; + } + gid_t rgid, egid, sgid; + rc = getresgid(&rgid, &egid, &sgid); + MaybeSave(); + if (rc < 0) { + return PosixError(errno, "getresgid"); + } + if (rgid != 0 || egid != 0 || sgid != 0) { + return false; + } + return true; +} + +} // namespace testing +} // namespace gvisor diff --git a/test/util/uid_util.h b/test/util/uid_util.h new file mode 100644 index 000000000..2cd387fb0 --- /dev/null +++ b/test/util/uid_util.h @@ -0,0 +1,29 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef GVISOR_TEST_SYSCALLS_UID_UTIL_H_ +#define GVISOR_TEST_SYSCALLS_UID_UTIL_H_ + +#include "test/util/posix_error.h" + +namespace gvisor { +namespace testing { + +// Returns true if the caller's real/effective/saved user/group IDs are all 0. +PosixErrorOr<bool> IsRoot(); + +} // namespace testing +} // namespace gvisor + +#endif // GVISOR_TEST_SYSCALLS_UID_UTIL_H_ |