diff options
46 files changed, 1064 insertions, 139 deletions
diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 149375896..638942a42 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -1,3 +1,5 @@ +# Contributing + Want to contribute? Great! First, read this page. ### Contributor License Agreement @@ -86,6 +88,11 @@ Code changes are accepted via [pull request][github]. When approved, the change will be submitted by a team member and automatically merged into the repository. +### Presubmit checks + +Accessing check logs may require membership in the +[gvisor-dev mailing list][gvisor-dev-list], which is public. + ### Bug IDs Some TODOs and NOTEs sprinkled throughout the code have associated IDs of the @@ -103,5 +110,6 @@ one above, the [gcla]: https://cla.developers.google.com/about/google-individual [gccla]: https://cla.developers.google.com/about/google-corporate [github]: https://github.com/google/gvisor/compare +[gvisor-dev-list]: https://groups.google.com/forum/#!forum/gvisor-dev [gostyle]: https://github.com/golang/go/wiki/CodeReviewComments [teststyle]: ./test/ diff --git a/pkg/sentry/fs/dev/tty.go b/pkg/sentry/fs/dev/tty.go index b4c2a62fd..87d80e292 100644 --- a/pkg/sentry/fs/dev/tty.go +++ b/pkg/sentry/fs/dev/tty.go @@ -32,6 +32,7 @@ type ttyInodeOperations struct { fsutil.InodeNoopWriteOut `state:"nosave"` fsutil.InodeNotDirectory `state:"nosave"` fsutil.InodeNotMappable `state:"nosave"` + fsutil.InodeNotOpenable `state:"nosave"` fsutil.InodeNotSocket `state:"nosave"` fsutil.InodeNotSymlink `state:"nosave"` fsutil.InodeVirtual `state:"nosave"` @@ -47,11 +48,6 @@ func newTTYDevice(ctx context.Context, owner fs.FileOwner, mode linux.FileMode) } } -// GetFile implements fs.InodeOperations.GetFile. -func (*ttyInodeOperations) GetFile(ctx context.Context, dirent *fs.Dirent, flags fs.FileFlags) (*fs.File, error) { - return fs.NewFile(ctx, dirent, flags, &ttyFileOperations{}), nil -} - // +stateify savable type ttyFileOperations struct { fsutil.FileNoSeek `state:"nosave"` diff --git a/pkg/sentry/fs/inode_overlay.go b/pkg/sentry/fs/inode_overlay.go index 24b769cfc..e0602da17 100644 --- a/pkg/sentry/fs/inode_overlay.go +++ b/pkg/sentry/fs/inode_overlay.go @@ -339,7 +339,9 @@ func overlayRemove(ctx context.Context, o *overlayEntry, parent *Dirent, child * } } if child.Inode.overlay.lowerExists { - return overlayCreateWhiteout(o.upper, child.name) + if err := overlayCreateWhiteout(o.upper, child.name); err != nil { + return err + } } // We've removed from the directory so we must drop the cache. o.markDirectoryDirty() @@ -418,10 +420,12 @@ func overlayRename(ctx context.Context, o *overlayEntry, oldParent *Dirent, rena return err } if renamed.Inode.overlay.lowerExists { - return overlayCreateWhiteout(oldParent.Inode.overlay.upper, oldName) + if err := overlayCreateWhiteout(oldParent.Inode.overlay.upper, oldName); err != nil { + return err + } } // We've changed the directory so we must drop the cache. - o.markDirectoryDirty() + oldParent.Inode.overlay.markDirectoryDirty() return nil } diff --git a/pkg/sentry/fs/proc/net.go b/pkg/sentry/fs/proc/net.go index 37694620c..6b839685b 100644 --- a/pkg/sentry/fs/proc/net.go +++ b/pkg/sentry/fs/proc/net.go @@ -155,37 +155,40 @@ func (n *netDev) ReadSeqFileData(ctx context.Context, h seqfile.SeqHandle) ([]se contents[1] = " face |bytes packets errs drop fifo frame compressed multicast|bytes packets errs drop fifo colls carrier compressed\n" for _, i := range interfaces { - // TODO(b/71872867): Collect stats from each inet.Stack - // implementation (hostinet, epsocket, and rpcinet). - // Implements the same format as // net/core/net-procfs.c:dev_seq_printf_stats. - l := fmt.Sprintf("%6s: %7d %7d %4d %4d %4d %5d %10d %9d %8d %7d %4d %4d %4d %5d %7d %10d\n", + var stats inet.StatDev + if err := n.s.Statistics(&stats, i.Name); err != nil { + log.Warningf("Failed to retrieve interface statistics for %v: %v", i.Name, err) + continue + } + l := fmt.Sprintf( + "%6s: %7d %7d %4d %4d %4d %5d %10d %9d %8d %7d %4d %4d %4d %5d %7d %10d\n", i.Name, // Received - 0, // bytes - 0, // packets - 0, // errors - 0, // dropped - 0, // fifo - 0, // frame - 0, // compressed - 0, // multicast + stats[0], // bytes + stats[1], // packets + stats[2], // errors + stats[3], // dropped + stats[4], // fifo + stats[5], // frame + stats[6], // compressed + stats[7], // multicast // Transmitted - 0, // bytes - 0, // packets - 0, // errors - 0, // dropped - 0, // fifo - 0, // frame - 0, // compressed - 0) // multicast + stats[8], // bytes + stats[9], // packets + stats[10], // errors + stats[11], // dropped + stats[12], // fifo + stats[13], // frame + stats[14], // compressed + stats[15]) // multicast contents = append(contents, l) } var data []seqfile.SeqData for _, l := range contents { - data = append(data, seqfile.SeqData{Buf: []byte(l), Handle: (*ifinet6)(nil)}) + data = append(data, seqfile.SeqData{Buf: []byte(l), Handle: (*netDev)(nil)}) } return data, 0 diff --git a/pkg/sentry/fs/proc/task.go b/pkg/sentry/fs/proc/task.go index ef0ca3301..87184ec67 100644 --- a/pkg/sentry/fs/proc/task.go +++ b/pkg/sentry/fs/proc/task.go @@ -162,6 +162,11 @@ func (f *subtasksFile) Readdir(ctx context.Context, file *fs.File, ser fs.Dentry // subtask to emit. offset := file.Offset() + tasks := f.t.ThreadGroup().MemberIDs(f.pidns) + if len(tasks) == 0 { + return offset, syserror.ENOENT + } + if offset == 0 { // Serialize "." and "..". root := fs.RootFromContext(ctx) @@ -178,12 +183,12 @@ func (f *subtasksFile) Readdir(ctx context.Context, file *fs.File, ser fs.Dentry } // Serialize tasks. - tasks := f.t.ThreadGroup().MemberIDs(f.pidns) taskInts := make([]int, 0, len(tasks)) for _, tid := range tasks { taskInts = append(taskInts, int(tid)) } + sort.Sort(sort.IntSlice(taskInts)) // Find the task to start at. idx := sort.SearchInts(taskInts, int(offset)) if idx == len(taskInts) { diff --git a/pkg/sentry/inet/inet.go b/pkg/sentry/inet/inet.go index 7c104fd47..5b75a4a06 100644 --- a/pkg/sentry/inet/inet.go +++ b/pkg/sentry/inet/inet.go @@ -49,6 +49,9 @@ type Stack interface { // SetTCPSACKEnabled attempts to change TCP selective acknowledgement // settings. SetTCPSACKEnabled(enabled bool) error + + // Statistics reports stack statistics. + Statistics(stat interface{}, arg string) error } // Interface contains information about a network interface. @@ -102,3 +105,7 @@ type TCPBufferSize struct { // Max is the maximum size. Max int } + +// StatDev describes one line of /proc/net/dev, i.e., stats for one network +// interface. +type StatDev [16]uint64 diff --git a/pkg/sentry/inet/test_stack.go b/pkg/sentry/inet/test_stack.go index 624371eb6..75f9e7a77 100644 --- a/pkg/sentry/inet/test_stack.go +++ b/pkg/sentry/inet/test_stack.go @@ -81,3 +81,8 @@ func (s *TestStack) SetTCPSACKEnabled(enabled bool) error { s.TCPSACKFlag = enabled return nil } + +// Statistics implements inet.Stack.Statistics. +func (s *TestStack) Statistics(stat interface{}, arg string) error { + return nil +} diff --git a/pkg/sentry/platform/kvm/machine_unsafe.go b/pkg/sentry/platform/kvm/machine_unsafe.go index 8d76e106e..405e00292 100644 --- a/pkg/sentry/platform/kvm/machine_unsafe.go +++ b/pkg/sentry/platform/kvm/machine_unsafe.go @@ -21,6 +21,7 @@ package kvm import ( "fmt" + "math" "sync/atomic" "syscall" "unsafe" @@ -134,7 +135,7 @@ func (c *vCPU) notify() { syscall.SYS_FUTEX, uintptr(unsafe.Pointer(&c.state)), linux.FUTEX_WAKE|linux.FUTEX_PRIVATE_FLAG, - ^uintptr(0), // Number of waiters. + math.MaxInt32, // Number of waiters. 0, 0, 0) if errno != 0 { throw("futex wake error") diff --git a/pkg/sentry/socket/epsocket/epsocket.go b/pkg/sentry/socket/epsocket/epsocket.go index 9d1bcfd41..69eff7373 100644 --- a/pkg/sentry/socket/epsocket/epsocket.go +++ b/pkg/sentry/socket/epsocket/epsocket.go @@ -867,6 +867,18 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfa return int32(v), nil + case linux.TCP_MAXSEG: + if outLen < sizeOfInt32 { + return nil, syserr.ErrInvalidArgument + } + + var v tcpip.MaxSegOption + if err := ep.GetSockOpt(&v); err != nil { + return nil, syserr.TranslateNetstackError(err) + } + + return int32(v), nil + case linux.TCP_KEEPIDLE: if outLen < sizeOfInt32 { return nil, syserr.ErrInvalidArgument @@ -1219,6 +1231,14 @@ func setSockOptTCP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) * v := usermem.ByteOrder.Uint32(optVal) return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.QuickAckOption(v))) + case linux.TCP_MAXSEG: + if len(optVal) < sizeOfInt32 { + return syserr.ErrInvalidArgument + } + + v := usermem.ByteOrder.Uint32(optVal) + return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.MaxSegOption(v))) + case linux.TCP_KEEPIDLE: if len(optVal) < sizeOfInt32 { return syserr.ErrInvalidArgument diff --git a/pkg/sentry/socket/epsocket/provider.go b/pkg/sentry/socket/epsocket/provider.go index 6d2b5d038..421f93dc4 100644 --- a/pkg/sentry/socket/epsocket/provider.go +++ b/pkg/sentry/socket/epsocket/provider.go @@ -40,42 +40,49 @@ type provider struct { } // getTransportProtocol figures out transport protocol. Currently only TCP, -// UDP, and ICMP are supported. -func getTransportProtocol(ctx context.Context, stype linux.SockType, protocol int) (tcpip.TransportProtocolNumber, *syserr.Error) { +// UDP, and ICMP are supported. The bool return value is true when this socket +// is associated with a transport protocol. This is only false for SOCK_RAW, +// IPPROTO_IP sockets. +func getTransportProtocol(ctx context.Context, stype linux.SockType, protocol int) (tcpip.TransportProtocolNumber, bool, *syserr.Error) { switch stype { case linux.SOCK_STREAM: if protocol != 0 && protocol != syscall.IPPROTO_TCP { - return 0, syserr.ErrInvalidArgument + return 0, true, syserr.ErrInvalidArgument } - return tcp.ProtocolNumber, nil + return tcp.ProtocolNumber, true, nil case linux.SOCK_DGRAM: switch protocol { case 0, syscall.IPPROTO_UDP: - return udp.ProtocolNumber, nil + return udp.ProtocolNumber, true, nil case syscall.IPPROTO_ICMP: - return header.ICMPv4ProtocolNumber, nil + return header.ICMPv4ProtocolNumber, true, nil case syscall.IPPROTO_ICMPV6: - return header.ICMPv6ProtocolNumber, nil + return header.ICMPv6ProtocolNumber, true, nil } case linux.SOCK_RAW: // Raw sockets require CAP_NET_RAW. creds := auth.CredentialsFromContext(ctx) if !creds.HasCapability(linux.CAP_NET_RAW) { - return 0, syserr.ErrPermissionDenied + return 0, true, syserr.ErrPermissionDenied } switch protocol { case syscall.IPPROTO_ICMP: - return header.ICMPv4ProtocolNumber, nil + return header.ICMPv4ProtocolNumber, true, nil case syscall.IPPROTO_UDP: - return header.UDPProtocolNumber, nil + return header.UDPProtocolNumber, true, nil case syscall.IPPROTO_TCP: - return header.TCPProtocolNumber, nil + return header.TCPProtocolNumber, true, nil + // IPPROTO_RAW signifies that the raw socket isn't assigned to + // a transport protocol. Users will be able to write packets' + // IP headers and won't receive anything. + case syscall.IPPROTO_RAW: + return tcpip.TransportProtocolNumber(0), false, nil } } - return 0, syserr.ErrProtocolNotSupported + return 0, true, syserr.ErrProtocolNotSupported } // Socket creates a new socket object for the AF_INET or AF_INET6 family. @@ -93,7 +100,7 @@ func (p *provider) Socket(t *kernel.Task, stype linux.SockType, protocol int) (* } // Figure out the transport protocol. - transProto, err := getTransportProtocol(t, stype, protocol) + transProto, associated, err := getTransportProtocol(t, stype, protocol) if err != nil { return nil, err } @@ -103,7 +110,7 @@ func (p *provider) Socket(t *kernel.Task, stype linux.SockType, protocol int) (* var e *tcpip.Error wq := &waiter.Queue{} if stype == linux.SOCK_RAW { - ep, e = eps.Stack.NewRawEndpoint(transProto, p.netProto, wq) + ep, e = eps.Stack.NewRawEndpoint(transProto, p.netProto, wq, associated) } else { ep, e = eps.Stack.NewEndpoint(transProto, p.netProto, wq) } diff --git a/pkg/sentry/socket/epsocket/stack.go b/pkg/sentry/socket/epsocket/stack.go index 1627a4f68..7eef19f74 100644 --- a/pkg/sentry/socket/epsocket/stack.go +++ b/pkg/sentry/socket/epsocket/stack.go @@ -138,3 +138,8 @@ func (s *Stack) TCPSACKEnabled() (bool, error) { func (s *Stack) SetTCPSACKEnabled(enabled bool) error { return syserr.TranslateNetstackError(s.Stack.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.SACKEnabled(enabled))).ToError() } + +// Statistics implements inet.Stack.Statistics. +func (s *Stack) Statistics(stat interface{}, arg string) error { + return syserr.ErrEndpointOperation.ToError() +} diff --git a/pkg/sentry/socket/hostinet/stack.go b/pkg/sentry/socket/hostinet/stack.go index 11f94281c..cc1f66fa1 100644 --- a/pkg/sentry/socket/hostinet/stack.go +++ b/pkg/sentry/socket/hostinet/stack.go @@ -244,3 +244,8 @@ func (s *Stack) TCPSACKEnabled() (bool, error) { func (s *Stack) SetTCPSACKEnabled(enabled bool) error { return syserror.EACCES } + +// Statistics implements inet.Stack.Statistics. +func (s *Stack) Statistics(stat interface{}, arg string) error { + return syserror.EOPNOTSUPP +} diff --git a/pkg/sentry/socket/rpcinet/stack.go b/pkg/sentry/socket/rpcinet/stack.go index 3038f25a7..49bd3a220 100644 --- a/pkg/sentry/socket/rpcinet/stack.go +++ b/pkg/sentry/socket/rpcinet/stack.go @@ -133,3 +133,8 @@ func (s *Stack) TCPSACKEnabled() (bool, error) { func (s *Stack) SetTCPSACKEnabled(enabled bool) error { panic("rpcinet handles procfs directly this method should not be called") } + +// Statistics implements inet.Stack.Statistics. +func (s *Stack) Statistics(stat interface{}, arg string) error { + return syserr.ErrEndpointOperation.ToError() +} diff --git a/pkg/sentry/syscalls/linux/linux64.go b/pkg/sentry/syscalls/linux/linux64.go index 680c0c64c..51db2d8f7 100644 --- a/pkg/sentry/syscalls/linux/linux64.go +++ b/pkg/sentry/syscalls/linux/linux64.go @@ -234,9 +234,9 @@ var AMD64 = &kernel.SyscallTable{ 185: syscalls.Error("security", syserror.ENOSYS, "Not implemented in Linux.", nil), 186: syscalls.Supported("gettid", Gettid), 187: syscalls.ErrorWithEvent("readahead", syserror.ENOSYS, "", []string{"gvisor.dev/issue/261"}), // TODO(b/29351341) - 188: syscalls.ErrorWithEvent("setxattr", syserror.ENOTSUP, "Requires filesystem support.", nil), - 189: syscalls.ErrorWithEvent("lsetxattr", syserror.ENOTSUP, "Requires filesystem support.", nil), - 190: syscalls.ErrorWithEvent("fsetxattr", syserror.ENOTSUP, "Requires filesystem support.", nil), + 188: syscalls.Error("setxattr", syserror.ENOTSUP, "Requires filesystem support.", nil), + 189: syscalls.Error("lsetxattr", syserror.ENOTSUP, "Requires filesystem support.", nil), + 190: syscalls.Error("fsetxattr", syserror.ENOTSUP, "Requires filesystem support.", nil), 191: syscalls.ErrorWithEvent("getxattr", syserror.ENOTSUP, "Requires filesystem support.", nil), 192: syscalls.ErrorWithEvent("lgetxattr", syserror.ENOTSUP, "Requires filesystem support.", nil), 193: syscalls.ErrorWithEvent("fgetxattr", syserror.ENOTSUP, "Requires filesystem support.", nil), diff --git a/pkg/tcpip/header/icmpv4.go b/pkg/tcpip/header/icmpv4.go index c081de61f..c52c0d851 100644 --- a/pkg/tcpip/header/icmpv4.go +++ b/pkg/tcpip/header/icmpv4.go @@ -24,15 +24,11 @@ import ( type ICMPv4 []byte const ( - // ICMPv4MinimumSize is the minimum size of a valid ICMP packet. - ICMPv4MinimumSize = 4 - - // ICMPv4EchoMinimumSize is the minimum size of a valid ICMP echo packet. - ICMPv4EchoMinimumSize = 6 + // ICMPv4PayloadOffset defines the start of ICMP payload. + ICMPv4PayloadOffset = 4 - // ICMPv4DstUnreachableMinimumSize is the minimum size of a valid ICMP - // destination unreachable packet. - ICMPv4DstUnreachableMinimumSize = ICMPv4MinimumSize + 4 + // ICMPv4MinimumSize is the minimum size of a valid ICMP packet. + ICMPv4MinimumSize = 8 // ICMPv4ProtocolNumber is the ICMP transport protocol number. ICMPv4ProtocolNumber tcpip.TransportProtocolNumber = 1 @@ -104,5 +100,5 @@ func (ICMPv4) SetDestinationPort(uint16) { // Payload implements Transport.Payload. func (b ICMPv4) Payload() []byte { - return b[ICMPv4MinimumSize:] + return b[ICMPv4PayloadOffset:] } diff --git a/pkg/tcpip/header/ipv4.go b/pkg/tcpip/header/ipv4.go index 7da4c4845..94a3af289 100644 --- a/pkg/tcpip/header/ipv4.go +++ b/pkg/tcpip/header/ipv4.go @@ -85,6 +85,10 @@ const ( // units, the header cannot exceed 15*4 = 60 bytes. IPv4MaximumHeaderSize = 60 + // MinIPFragmentPayloadSize is the minimum number of payload bytes that + // the first fragment must carry when an IPv4 packet is fragmented. + MinIPFragmentPayloadSize = 8 + // IPv4AddressSize is the size, in bytes, of an IPv4 address. IPv4AddressSize = 4 @@ -268,6 +272,10 @@ func (b IPv4) IsValid(pktSize int) bool { return false } + if IPVersion(b) != IPv4Version { + return false + } + return true } diff --git a/pkg/tcpip/header/ipv6.go b/pkg/tcpip/header/ipv6.go index 7163eaa36..95fe8bfc3 100644 --- a/pkg/tcpip/header/ipv6.go +++ b/pkg/tcpip/header/ipv6.go @@ -184,6 +184,10 @@ func (b IPv6) IsValid(pktSize int) bool { return false } + if IPVersion(b) != IPv6Version { + return false + } + return true } diff --git a/pkg/tcpip/header/tcp.go b/pkg/tcpip/header/tcp.go index 1141443bb..82cfe785c 100644 --- a/pkg/tcpip/header/tcp.go +++ b/pkg/tcpip/header/tcp.go @@ -176,6 +176,21 @@ const ( // TCPProtocolNumber is TCP's transport protocol number. TCPProtocolNumber tcpip.TransportProtocolNumber = 6 + + // TCPMinimumMSS is the minimum acceptable value for MSS. This is the + // same as the value TCP_MIN_MSS defined net/tcp.h. + TCPMinimumMSS = IPv4MaximumHeaderSize + TCPHeaderMaximumSize + MinIPFragmentPayloadSize - IPv4MinimumSize - TCPMinimumSize + + // TCPMaximumMSS is the maximum acceptable value for MSS. + TCPMaximumMSS = 0xffff + + // TCPDefaultMSS is the MSS value that should be used if an MSS option + // is not received from the peer. It's also the value returned by + // TCP_MAXSEG option for a socket in an unconnected state. + // + // Per RFC 1122, page 85: "If an MSS option is not received at + // connection setup, TCP MUST assume a default send MSS of 536." + TCPDefaultMSS = 536 ) // SourcePort returns the "source port" field of the tcp header. @@ -306,7 +321,7 @@ func ParseSynOptions(opts []byte, isAck bool) TCPSynOptions { synOpts := TCPSynOptions{ // Per RFC 1122, page 85: "If an MSS option is not received at // connection setup, TCP MUST assume a default send MSS of 536." - MSS: 536, + MSS: TCPDefaultMSS, // If no window scale option is specified, WS in options is // returned as -1; this is because the absence of the option // indicates that the we cannot use window scaling on the diff --git a/pkg/tcpip/network/arp/arp.go b/pkg/tcpip/network/arp/arp.go index ca3d6c0bf..cb35635fc 100644 --- a/pkg/tcpip/network/arp/arp.go +++ b/pkg/tcpip/network/arp/arp.go @@ -83,6 +83,10 @@ func (e *endpoint) WritePacket(*stack.Route, *stack.GSO, buffer.Prependable, buf return tcpip.ErrNotSupported } +func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, payload buffer.VectorisedView, loop stack.PacketLooping) *tcpip.Error { + return tcpip.ErrNotSupported +} + func (e *endpoint) HandlePacket(r *stack.Route, vv buffer.VectorisedView) { v := vv.First() h := header.ARP(v) diff --git a/pkg/tcpip/network/ip_test.go b/pkg/tcpip/network/ip_test.go index db65ee7cc..8ff428445 100644 --- a/pkg/tcpip/network/ip_test.go +++ b/pkg/tcpip/network/ip_test.go @@ -282,10 +282,10 @@ func TestIPv4ReceiveControl(t *testing.T) { {"Truncated (10 bytes missing)", 0, 0, header.ICMPv4FragmentationNeeded, stack.ControlPacketTooBig, mtu, 10}, {"Truncated (missing IPv4 header)", 0, 0, header.ICMPv4FragmentationNeeded, stack.ControlPacketTooBig, mtu, header.IPv4MinimumSize + 8}, {"Truncated (missing 'extra info')", 0, 0, header.ICMPv4FragmentationNeeded, stack.ControlPacketTooBig, mtu, 4 + header.IPv4MinimumSize + 8}, - {"Truncated (missing ICMP header)", 0, 0, header.ICMPv4FragmentationNeeded, stack.ControlPacketTooBig, mtu, header.ICMPv4DstUnreachableMinimumSize + header.IPv4MinimumSize + 8}, + {"Truncated (missing ICMP header)", 0, 0, header.ICMPv4FragmentationNeeded, stack.ControlPacketTooBig, mtu, header.ICMPv4MinimumSize + header.IPv4MinimumSize + 8}, {"Port unreachable", 1, 0, header.ICMPv4PortUnreachable, stack.ControlPortUnreachable, 0, 0}, {"Non-zero fragment offset", 0, 100, header.ICMPv4PortUnreachable, stack.ControlPortUnreachable, 0, 0}, - {"Zero-length packet", 0, 0, header.ICMPv4PortUnreachable, stack.ControlPortUnreachable, 0, 2*header.IPv4MinimumSize + header.ICMPv4DstUnreachableMinimumSize + 8}, + {"Zero-length packet", 0, 0, header.ICMPv4PortUnreachable, stack.ControlPortUnreachable, 0, 2*header.IPv4MinimumSize + header.ICMPv4MinimumSize + 8}, } r, err := buildIPv4Route(localIpv4Addr, "\x0a\x00\x00\xbb") if err != nil { @@ -301,7 +301,7 @@ func TestIPv4ReceiveControl(t *testing.T) { } defer ep.Close() - const dataOffset = header.IPv4MinimumSize*2 + header.ICMPv4MinimumSize + 4 + const dataOffset = header.IPv4MinimumSize*2 + header.ICMPv4MinimumSize view := buffer.NewView(dataOffset + 8) // Create the outer IPv4 header. @@ -319,10 +319,10 @@ func TestIPv4ReceiveControl(t *testing.T) { icmp := header.ICMPv4(view[header.IPv4MinimumSize:]) icmp.SetType(header.ICMPv4DstUnreachable) icmp.SetCode(c.code) - copy(view[header.IPv4MinimumSize+header.ICMPv4MinimumSize:], []byte{0xde, 0xad, 0xbe, 0xef}) + copy(view[header.IPv4MinimumSize+header.ICMPv4PayloadOffset:], []byte{0xde, 0xad, 0xbe, 0xef}) // Create the inner IPv4 header. - ip = header.IPv4(view[header.IPv4MinimumSize+header.ICMPv4MinimumSize+4:]) + ip = header.IPv4(view[header.IPv4MinimumSize+header.ICMPv4MinimumSize:]) ip.Encode(&header.IPv4Fields{ IHL: header.IPv4MinimumSize, TotalLength: 100, diff --git a/pkg/tcpip/network/ipv4/icmp.go b/pkg/tcpip/network/ipv4/icmp.go index bc7f1c42a..fbef6947d 100644 --- a/pkg/tcpip/network/ipv4/icmp.go +++ b/pkg/tcpip/network/ipv4/icmp.go @@ -68,10 +68,6 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, vv buffer.V switch h.Type() { case header.ICMPv4Echo: received.Echo.Increment() - if len(v) < header.ICMPv4EchoMinimumSize { - received.Invalid.Increment() - return - } // Only send a reply if the checksum is valid. wantChecksum := h.Checksum() @@ -93,9 +89,9 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, vv buffer.V e.dispatcher.DeliverTransportPacket(r, header.ICMPv4ProtocolNumber, netHeader, vv) vv := vv.Clone(nil) - vv.TrimFront(header.ICMPv4EchoMinimumSize) - hdr := buffer.NewPrependable(int(r.MaxHeaderLength()) + header.ICMPv4EchoMinimumSize) - pkt := header.ICMPv4(hdr.Prepend(header.ICMPv4EchoMinimumSize)) + vv.TrimFront(header.ICMPv4MinimumSize) + hdr := buffer.NewPrependable(int(r.MaxHeaderLength()) + header.ICMPv4MinimumSize) + pkt := header.ICMPv4(hdr.Prepend(header.ICMPv4MinimumSize)) copy(pkt, h) pkt.SetType(header.ICMPv4EchoReply) pkt.SetChecksum(^header.Checksum(pkt, header.ChecksumVV(vv, 0))) @@ -108,25 +104,19 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, vv buffer.V case header.ICMPv4EchoReply: received.EchoReply.Increment() - if len(v) < header.ICMPv4EchoMinimumSize { - received.Invalid.Increment() - return - } + e.dispatcher.DeliverTransportPacket(r, header.ICMPv4ProtocolNumber, netHeader, vv) case header.ICMPv4DstUnreachable: received.DstUnreachable.Increment() - if len(v) < header.ICMPv4DstUnreachableMinimumSize { - received.Invalid.Increment() - return - } - vv.TrimFront(header.ICMPv4DstUnreachableMinimumSize) + + vv.TrimFront(header.ICMPv4MinimumSize) switch h.Code() { case header.ICMPv4PortUnreachable: e.handleControl(stack.ControlPortUnreachable, 0, vv) case header.ICMPv4FragmentationNeeded: - mtu := uint32(binary.BigEndian.Uint16(v[header.ICMPv4DstUnreachableMinimumSize-2:])) + mtu := uint32(binary.BigEndian.Uint16(v[header.ICMPv4PayloadOffset+2:])) e.handleControl(stack.ControlPacketTooBig, calculateMTU(mtu), vv) } diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go index 1e3a7425a..e44a73d96 100644 --- a/pkg/tcpip/network/ipv4/ipv4.go +++ b/pkg/tcpip/network/ipv4/ipv4.go @@ -232,6 +232,55 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, hdr buffer.Prepen return nil } +// WriteHeaderIncludedPacket writes a packet already containing a network +// header through the given route. +func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, payload buffer.VectorisedView, loop stack.PacketLooping) *tcpip.Error { + // The packet already has an IP header, but there are a few required + // checks. + ip := header.IPv4(payload.First()) + if !ip.IsValid(payload.Size()) { + return tcpip.ErrInvalidOptionValue + } + + // Always set the total length. + ip.SetTotalLength(uint16(payload.Size())) + + // Set the source address when zero. + if ip.SourceAddress() == tcpip.Address(([]byte{0, 0, 0, 0})) { + ip.SetSourceAddress(r.LocalAddress) + } + + // Set the destination. If the packet already included a destination, + // it will be part of the route. + ip.SetDestinationAddress(r.RemoteAddress) + + // Set the packet ID when zero. + if ip.ID() == 0 { + id := uint32(0) + if payload.Size() > header.IPv4MaximumHeaderSize+8 { + // Packets of 68 bytes or less are required by RFC 791 to not be + // fragmented, so we only assign ids to larger packets. + id = atomic.AddUint32(&ids[hashRoute(r, 0 /* protocol */)%buckets], 1) + } + ip.SetID(uint16(id)) + } + + // Always set the checksum. + ip.SetChecksum(0) + ip.SetChecksum(^ip.CalculateChecksum()) + + if loop&stack.PacketLoop != 0 { + e.HandlePacket(r, payload) + } + if loop&stack.PacketOut == 0 { + return nil + } + + hdr := buffer.NewPrependableFromView(payload.ToView()) + r.Stats().IP.PacketsSent.Increment() + return e.linkEP.WritePacket(r, nil /* gso */, hdr, buffer.VectorisedView{}, ProtocolNumber) +} + // HandlePacket is called by the link layer when new ipv4 packets arrive for // this endpoint. func (e *endpoint) HandlePacket(r *stack.Route, vv buffer.VectorisedView) { diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go index 27367d6c5..e3e8739fd 100644 --- a/pkg/tcpip/network/ipv6/ipv6.go +++ b/pkg/tcpip/network/ipv6/ipv6.go @@ -120,6 +120,13 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, hdr buffer.Prepen return e.linkEP.WritePacket(r, gso, hdr, payload, ProtocolNumber) } +// WriteHeaderIncludedPacker implements stack.NetworkEndpoint. It is not yet +// supported by IPv6. +func (*endpoint) WriteHeaderIncludedPacket(r *stack.Route, payload buffer.VectorisedView, loop stack.PacketLooping) *tcpip.Error { + // TODO(b/119580726): Support IPv6 header-included packets. + return tcpip.ErrNotSupported +} + // HandlePacket is called by the link layer when new ipv6 packets arrive for // this endpoint. func (e *endpoint) HandlePacket(r *stack.Route, vv buffer.VectorisedView) { diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go index 0ecaa0833..462265281 100644 --- a/pkg/tcpip/stack/registration.go +++ b/pkg/tcpip/stack/registration.go @@ -174,6 +174,10 @@ type NetworkEndpoint interface { // protocol. WritePacket(r *Route, gso *GSO, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.TransportProtocolNumber, ttl uint8, loop PacketLooping) *tcpip.Error + // WriteHeaderIncludedPacket writes a packet that includes a network + // header to the given destination address. + WriteHeaderIncludedPacket(r *Route, payload buffer.VectorisedView, loop PacketLooping) *tcpip.Error + // ID returns the network protocol endpoint ID. ID() *NetworkEndpointID @@ -357,10 +361,19 @@ type TransportProtocolFactory func() TransportProtocol // instantiate network protocols. type NetworkProtocolFactory func() NetworkProtocol +// UnassociatedEndpointFactory produces endpoints for writing packets not +// associated with a particular transport protocol. Such endpoints can be used +// to write arbitrary packets that include the IP header. +type UnassociatedEndpointFactory interface { + NewUnassociatedRawEndpoint(stack *Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) +} + var ( transportProtocols = make(map[string]TransportProtocolFactory) networkProtocols = make(map[string]NetworkProtocolFactory) + unassociatedFactory UnassociatedEndpointFactory + linkEPMu sync.RWMutex nextLinkEndpointID tcpip.LinkEndpointID = 1 linkEndpoints = make(map[tcpip.LinkEndpointID]LinkEndpoint) @@ -380,6 +393,13 @@ func RegisterNetworkProtocolFactory(name string, p NetworkProtocolFactory) { networkProtocols[name] = p } +// RegisterUnassociatedFactory registers a factory to produce endpoints not +// associated with any particular transport protocol. This function is intended +// to be called by init() functions of the protocols. +func RegisterUnassociatedFactory(f UnassociatedEndpointFactory) { + unassociatedFactory = f +} + // RegisterLinkEndpoint register a link-layer protocol endpoint and returns an // ID that can be used to refer to it. func RegisterLinkEndpoint(linkEP LinkEndpoint) tcpip.LinkEndpointID { diff --git a/pkg/tcpip/stack/route.go b/pkg/tcpip/stack/route.go index 36d7b6ac7..391ab4344 100644 --- a/pkg/tcpip/stack/route.go +++ b/pkg/tcpip/stack/route.go @@ -163,6 +163,18 @@ func (r *Route) WritePacket(gso *GSO, hdr buffer.Prependable, payload buffer.Vec return err } +// WriteHeaderIncludedPacket writes a packet already containing a network +// header through the given route. +func (r *Route) WriteHeaderIncludedPacket(payload buffer.VectorisedView) *tcpip.Error { + if err := r.ref.ep.WriteHeaderIncludedPacket(r, payload, r.loop); err != nil { + r.Stats().IP.OutgoingPacketErrors.Increment() + return err + } + r.ref.nic.stats.Tx.Packets.Increment() + r.ref.nic.stats.Tx.Bytes.IncrementBy(uint64(payload.Size())) + return nil +} + // DefaultTTL returns the default TTL of the underlying network endpoint. func (r *Route) DefaultTTL() uint8 { return r.ref.ep.DefaultTTL() diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index 2d7f56ca9..3e8fb2a6c 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -340,6 +340,8 @@ type Stack struct { networkProtocols map[tcpip.NetworkProtocolNumber]NetworkProtocol linkAddrResolvers map[tcpip.NetworkProtocolNumber]LinkAddressResolver + unassociatedFactory UnassociatedEndpointFactory + demux *transportDemuxer stats tcpip.Stats @@ -442,6 +444,8 @@ func New(network []string, transport []string, opts Options) *Stack { } } + s.unassociatedFactory = unassociatedFactory + // Create the global transport demuxer. s.demux = newTransportDemuxer(s) @@ -574,11 +578,15 @@ func (s *Stack) NewEndpoint(transport tcpip.TransportProtocolNumber, network tcp // NewRawEndpoint creates a new raw transport layer endpoint of the given // protocol. Raw endpoints receive all traffic for a given protocol regardless // of address. -func (s *Stack) NewRawEndpoint(transport tcpip.TransportProtocolNumber, network tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { +func (s *Stack) NewRawEndpoint(transport tcpip.TransportProtocolNumber, network tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue, associated bool) (tcpip.Endpoint, *tcpip.Error) { if !s.raw { return nil, tcpip.ErrNotPermitted } + if !associated { + return s.unassociatedFactory.NewUnassociatedRawEndpoint(s, network, transport, waiterQueue) + } + t, ok := s.transportProtocols[transport] if !ok { return nil, tcpip.ErrUnknownProtocol diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go index 69884af03..959071dbe 100644 --- a/pkg/tcpip/stack/stack_test.go +++ b/pkg/tcpip/stack/stack_test.go @@ -137,6 +137,10 @@ func (f *fakeNetworkEndpoint) WritePacket(r *stack.Route, gso *stack.GSO, hdr bu return f.linkEP.WritePacket(r, gso, hdr, payload, fakeNetNumber) } +func (*fakeNetworkEndpoint) WriteHeaderIncludedPacket(r *stack.Route, payload buffer.VectorisedView, loop stack.PacketLooping) *tcpip.Error { + return tcpip.ErrNotSupported +} + func (*fakeNetworkEndpoint) Close() {} type fakeNetGoodOption bool diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go index c61f96fb0..c4076666a 100644 --- a/pkg/tcpip/tcpip.go +++ b/pkg/tcpip/tcpip.go @@ -496,6 +496,10 @@ type AvailableCongestionControlOption string // buffer moderation. type ModerateReceiveBufferOption bool +// MaxSegOption is used by SetSockOpt/GetSockOpt to set/get the current +// Maximum Segment Size(MSS) value as specified using the TCP_MAXSEG option. +type MaxSegOption int + // MulticastTTLOption is used by SetSockOpt/GetSockOpt to control the default // TTL value for multicast messages. The default is 1. type MulticastTTLOption uint8 diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go index ab9e80747..a80ceafd0 100644 --- a/pkg/tcpip/transport/icmp/endpoint.go +++ b/pkg/tcpip/transport/icmp/endpoint.go @@ -291,7 +291,7 @@ func (e *endpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (uintptr, <-c switch e.netProto { case header.IPv4ProtocolNumber: - err = e.send4(route, v) + err = send4(route, e.id.LocalPort, v) case header.IPv6ProtocolNumber: err = send6(route, e.id.LocalPort, v) @@ -352,20 +352,20 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { } } -func (e *endpoint) send4(r *stack.Route, data buffer.View) *tcpip.Error { - if len(data) < header.ICMPv4EchoMinimumSize { +func send4(r *stack.Route, ident uint16, data buffer.View) *tcpip.Error { + if len(data) < header.ICMPv4MinimumSize { return tcpip.ErrInvalidEndpointState } // Set the ident to the user-specified port. Sequence number should // already be set by the user. - binary.BigEndian.PutUint16(data[header.ICMPv4MinimumSize:], e.id.LocalPort) + binary.BigEndian.PutUint16(data[header.ICMPv4PayloadOffset:], ident) - hdr := buffer.NewPrependable(header.ICMPv4EchoMinimumSize + int(r.MaxHeaderLength())) + hdr := buffer.NewPrependable(header.ICMPv4MinimumSize + int(r.MaxHeaderLength())) - icmpv4 := header.ICMPv4(hdr.Prepend(header.ICMPv4EchoMinimumSize)) + icmpv4 := header.ICMPv4(hdr.Prepend(header.ICMPv4MinimumSize)) copy(icmpv4, data) - data = data[header.ICMPv4EchoMinimumSize:] + data = data[header.ICMPv4MinimumSize:] // Linux performs these basic checks. if icmpv4.Type() != header.ICMPv4Echo || icmpv4.Code() != 0 { diff --git a/pkg/tcpip/transport/icmp/protocol.go b/pkg/tcpip/transport/icmp/protocol.go index c89538131..7fdba5d56 100644 --- a/pkg/tcpip/transport/icmp/protocol.go +++ b/pkg/tcpip/transport/icmp/protocol.go @@ -90,19 +90,18 @@ func (p *protocol) NewRawEndpoint(stack *stack.Stack, netProto tcpip.NetworkProt func (p *protocol) MinimumPacketSize() int { switch p.number { case ProtocolNumber4: - return header.ICMPv4EchoMinimumSize + return header.ICMPv4MinimumSize case ProtocolNumber6: return header.ICMPv6EchoMinimumSize } panic(fmt.Sprint("unknown protocol number: ", p.number)) } -// ParsePorts returns the source and destination ports stored in the given icmp -// packet. +// ParsePorts in case of ICMP sets src to 0, dst to ICMP ID, and err to nil. func (p *protocol) ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error) { switch p.number { case ProtocolNumber4: - return 0, binary.BigEndian.Uint16(v[header.ICMPv4MinimumSize:]), nil + return 0, binary.BigEndian.Uint16(v[header.ICMPv4PayloadOffset:]), nil case ProtocolNumber6: return 0, binary.BigEndian.Uint16(v[header.ICMPv6MinimumSize:]), nil } diff --git a/pkg/tcpip/transport/raw/BUILD b/pkg/tcpip/transport/raw/BUILD index 34a14bf7f..bc4b255b4 100644 --- a/pkg/tcpip/transport/raw/BUILD +++ b/pkg/tcpip/transport/raw/BUILD @@ -21,6 +21,7 @@ go_library( "endpoint.go", "endpoint_state.go", "packet_list.go", + "protocol.go", ], importpath = "gvisor.dev/gvisor/pkg/tcpip/transport/raw", imports = ["gvisor.dev/gvisor/pkg/tcpip/buffer"], diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go index 42aded77f..a29587658 100644 --- a/pkg/tcpip/transport/raw/endpoint.go +++ b/pkg/tcpip/transport/raw/endpoint.go @@ -67,6 +67,7 @@ type endpoint struct { netProto tcpip.NetworkProtocolNumber transProto tcpip.TransportProtocolNumber waiterQueue *waiter.Queue + associated bool // The following fields are used to manage the receive queue and are // protected by rcvMu. @@ -97,8 +98,12 @@ type endpoint struct { } // NewEndpoint returns a raw endpoint for the given protocols. -// TODO(b/129292371): IP_HDRINCL, IPPROTO_RAW, and AF_PACKET. +// TODO(b/129292371): IP_HDRINCL and AF_PACKET. func NewEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { + return newEndpoint(stack, netProto, transProto, waiterQueue, true /* associated */) +} + +func newEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue, associated bool) (tcpip.Endpoint, *tcpip.Error) { if netProto != header.IPv4ProtocolNumber { return nil, tcpip.ErrUnknownProtocol } @@ -110,6 +115,16 @@ func NewEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, trans waiterQueue: waiterQueue, rcvBufSizeMax: 32 * 1024, sndBufSize: 32 * 1024, + associated: associated, + } + + // Unassociated endpoints are write-only and users call Write() with IP + // headers included. Because they're write-only, We don't need to + // register with the stack. + if !associated { + ep.rcvBufSizeMax = 0 + ep.waiterQueue = nil + return ep, nil } if err := ep.stack.RegisterRawTransportEndpoint(ep.registeredNIC, ep.netProto, ep.transProto, ep); err != nil { @@ -124,7 +139,7 @@ func (ep *endpoint) Close() { ep.mu.Lock() defer ep.mu.Unlock() - if ep.closed { + if ep.closed || !ep.associated { return } @@ -142,8 +157,11 @@ func (ep *endpoint) Close() { if ep.connected { ep.route.Release() + ep.connected = false } + ep.closed = true + ep.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.EventIn | waiter.EventOut) } @@ -152,6 +170,10 @@ func (ep *endpoint) ModerateRecvBuf(copied int) {} // Read implements tcpip.Endpoint.Read. func (ep *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) { + if !ep.associated { + return buffer.View{}, tcpip.ControlMessages{}, tcpip.ErrInvalidOptionValue + } + ep.rcvMu.Lock() // If there's no data to read, return that read would block or that the @@ -192,6 +214,33 @@ func (ep *endpoint) Write(payload tcpip.Payload, opts tcpip.WriteOptions) (uintp return 0, nil, tcpip.ErrInvalidEndpointState } + payloadBytes, err := payload.Get(payload.Size()) + if err != nil { + ep.mu.RUnlock() + return 0, nil, err + } + + // If this is an unassociated socket and callee provided a nonzero + // destination address, route using that address. + if !ep.associated { + ip := header.IPv4(payloadBytes) + if !ip.IsValid(payload.Size()) { + ep.mu.RUnlock() + return 0, nil, tcpip.ErrInvalidOptionValue + } + dstAddr := ip.DestinationAddress() + // Update dstAddr with the address in the IP header, unless + // opts.To is set (e.g. if sendto specifies a specific + // address). + if dstAddr != tcpip.Address([]byte{0, 0, 0, 0}) && opts.To == nil { + opts.To = &tcpip.FullAddress{ + NIC: 0, // NIC is unset. + Addr: dstAddr, // The address from the payload. + Port: 0, // There are no ports here. + } + } + } + // Did the user caller provide a destination? If not, use the connected // destination. if opts.To == nil { @@ -216,12 +265,12 @@ func (ep *endpoint) Write(payload tcpip.Payload, opts tcpip.WriteOptions) (uintp return 0, nil, tcpip.ErrInvalidEndpointState } - n, ch, err := ep.finishWrite(payload, savedRoute) + n, ch, err := ep.finishWrite(payloadBytes, savedRoute) ep.mu.Unlock() return n, ch, err } - n, ch, err := ep.finishWrite(payload, &ep.route) + n, ch, err := ep.finishWrite(payloadBytes, &ep.route) ep.mu.RUnlock() return n, ch, err } @@ -248,7 +297,7 @@ func (ep *endpoint) Write(payload tcpip.Payload, opts tcpip.WriteOptions) (uintp return 0, nil, err } - n, ch, err := ep.finishWrite(payload, &route) + n, ch, err := ep.finishWrite(payloadBytes, &route) route.Release() ep.mu.RUnlock() return n, ch, err @@ -256,7 +305,7 @@ func (ep *endpoint) Write(payload tcpip.Payload, opts tcpip.WriteOptions) (uintp // finishWrite writes the payload to a route. It resolves the route if // necessary. It's really just a helper to make defer unnecessary in Write. -func (ep *endpoint) finishWrite(payload tcpip.Payload, route *stack.Route) (uintptr, <-chan struct{}, *tcpip.Error) { +func (ep *endpoint) finishWrite(payloadBytes []byte, route *stack.Route) (uintptr, <-chan struct{}, *tcpip.Error) { // We may need to resolve the route (match a link layer address to the // network address). If that requires blocking (e.g. to use ARP), // return a channel on which the caller can wait. @@ -269,13 +318,14 @@ func (ep *endpoint) finishWrite(payload tcpip.Payload, route *stack.Route) (uint } } - payloadBytes, err := payload.Get(payload.Size()) - if err != nil { - return 0, nil, err - } - switch ep.netProto { case header.IPv4ProtocolNumber: + if !ep.associated { + if err := route.WriteHeaderIncludedPacket(buffer.View(payloadBytes).ToVectorisedView()); err != nil { + return 0, nil, err + } + break + } hdr := buffer.NewPrependable(len(payloadBytes) + int(route.MaxHeaderLength())) if err := route.WritePacket(nil /* gso */, hdr, buffer.View(payloadBytes).ToVectorisedView(), ep.transProto, route.DefaultTTL()); err != nil { return 0, nil, err @@ -335,15 +385,17 @@ func (ep *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { } defer route.Release() - // Re-register the endpoint with the appropriate NIC. - if err := ep.stack.RegisterRawTransportEndpoint(addr.NIC, ep.netProto, ep.transProto, ep); err != nil { - return err + if ep.associated { + // Re-register the endpoint with the appropriate NIC. + if err := ep.stack.RegisterRawTransportEndpoint(addr.NIC, ep.netProto, ep.transProto, ep); err != nil { + return err + } + ep.stack.UnregisterRawTransportEndpoint(ep.registeredNIC, ep.netProto, ep.transProto, ep) + ep.registeredNIC = nic } - ep.stack.UnregisterRawTransportEndpoint(ep.registeredNIC, ep.netProto, ep.transProto, ep) - // Save the route and NIC we've connected via. + // Save the route we've connected via. ep.route = route.Clone() - ep.registeredNIC = nic ep.connected = true return nil @@ -386,14 +438,16 @@ func (ep *endpoint) Bind(addr tcpip.FullAddress) *tcpip.Error { return tcpip.ErrBadLocalAddress } - // Re-register the endpoint with the appropriate NIC. - if err := ep.stack.RegisterRawTransportEndpoint(addr.NIC, ep.netProto, ep.transProto, ep); err != nil { - return err + if ep.associated { + // Re-register the endpoint with the appropriate NIC. + if err := ep.stack.RegisterRawTransportEndpoint(addr.NIC, ep.netProto, ep.transProto, ep); err != nil { + return err + } + ep.stack.UnregisterRawTransportEndpoint(ep.registeredNIC, ep.netProto, ep.transProto, ep) + ep.registeredNIC = addr.NIC + ep.boundNIC = addr.NIC } - ep.stack.UnregisterRawTransportEndpoint(ep.registeredNIC, ep.netProto, ep.transProto, ep) - ep.registeredNIC = addr.NIC - ep.boundNIC = addr.NIC ep.boundAddr = addr.Addr ep.bound = true diff --git a/pkg/tcpip/transport/raw/protocol.go b/pkg/tcpip/transport/raw/protocol.go new file mode 100644 index 000000000..783c21e6b --- /dev/null +++ b/pkg/tcpip/transport/raw/protocol.go @@ -0,0 +1,32 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package raw + +import ( + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/stack" + "gvisor.dev/gvisor/pkg/waiter" +) + +type factory struct{} + +// NewUnassociatedRawEndpoint implements stack.UnassociatedEndpointFactory. +func (factory) NewUnassociatedRawEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { + return newEndpoint(stack, netProto, transProto, waiterQueue, false /* associated */) +} + +func init() { + stack.RegisterUnassociatedFactory(factory{}) +} diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index cb40fea94..beb90afb5 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -117,6 +117,7 @@ const ( notifyDrain notifyReset notifyKeepaliveChanged + notifyMSSChanged ) // SACKInfo holds TCP SACK related information for a given endpoint. @@ -218,8 +219,6 @@ type endpoint struct { mu sync.RWMutex `state:"nosave"` id stack.TransportEndpointID - // state endpointState `state:".(endpointState)"` - // pState ProtocolState state EndpointState `state:".(EndpointState)"` isPortReserved bool `state:"manual"` @@ -313,6 +312,10 @@ type endpoint struct { // in SYN-RCVD state. synRcvdCount int + // userMSS if non-zero is the MSS value explicitly set by the user + // for this endpoint using the TCP_MAXSEG setsockopt. + userMSS int + // The following fields are used to manage the send buffer. When // segments are ready to be sent, they are added to sndQueue and the // protocol goroutine is signaled via sndWaker. @@ -917,6 +920,17 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { } return nil + case tcpip.MaxSegOption: + userMSS := v + if userMSS < header.TCPMinimumMSS || userMSS > header.TCPMaximumMSS { + return tcpip.ErrInvalidOptionValue + } + e.mu.Lock() + e.userMSS = int(userMSS) + e.mu.Unlock() + e.notifyProtocolGoroutine(notifyMSSChanged) + return nil + case tcpip.ReceiveBufferSizeOption: // Make sure the receive buffer size is within the min and max // allowed. @@ -1096,6 +1110,14 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { e.lastErrorMu.Unlock() return err + case *tcpip.MaxSegOption: + // This is just stubbed out. Linux never returns the user_mss + // value as it either returns the defaultMSS or returns the + // actual current MSS. Netstack just returns the defaultMSS + // always for now. + *o = header.TCPDefaultMSS + return nil + case *tcpip.SendBufferSizeOption: e.sndBufMu.Lock() *o = tcpip.SendBufferSizeOption(e.sndBufSize) diff --git a/pkg/tcpip/transport/tcp/testing/context/context.go b/pkg/tcpip/transport/tcp/testing/context/context.go index 630dd7925..bcc0f3e28 100644 --- a/pkg/tcpip/transport/tcp/testing/context/context.go +++ b/pkg/tcpip/transport/tcp/testing/context/context.go @@ -271,7 +271,7 @@ func (c *Context) GetPacketNonBlocking() []byte { // SendICMPPacket builds and sends an ICMPv4 packet via the link layer endpoint. func (c *Context) SendICMPPacket(typ header.ICMPv4Type, code uint8, p1, p2 []byte, maxTotalSize int) { // Allocate a buffer data and headers. - buf := buffer.NewView(header.IPv4MinimumSize + header.ICMPv4MinimumSize + len(p1) + len(p2)) + buf := buffer.NewView(header.IPv4MinimumSize + header.ICMPv4PayloadOffset + len(p1) + len(p2)) if len(buf) > maxTotalSize { buf = buf[:maxTotalSize] } @@ -291,8 +291,8 @@ func (c *Context) SendICMPPacket(typ header.ICMPv4Type, code uint8, p1, p2 []byt icmp.SetType(typ) icmp.SetCode(code) - copy(icmp[header.ICMPv4MinimumSize:], p1) - copy(icmp[header.ICMPv4MinimumSize+len(p1):], p2) + copy(icmp[header.ICMPv4PayloadOffset:], p1) + copy(icmp[header.ICMPv4PayloadOffset+len(p1):], p2) // Inject packet. c.linkEP.Inject(ipv4.ProtocolNumber, buf.ToVectorisedView()) diff --git a/runsc/boot/fs.go b/runsc/boot/fs.go index f9a6f2d3c..d3e3196fd 100644 --- a/runsc/boot/fs.go +++ b/runsc/boot/fs.go @@ -906,7 +906,10 @@ func (c *containerMounter) mountSharedSubmount(ctx context.Context, mns *fs.Moun } defer target.DecRef() + // Take a ref on the inode that is about to be (re)-mounted. + source.root.IncRef() if err := mns.Mount(ctx, target, source.root); err != nil { + source.root.DecRef() return fmt.Errorf("bind mount %q error: %v", mount.Destination, err) } diff --git a/runsc/test/integration/regression_test.go b/runsc/test/integration/regression_test.go index 39b30e757..fb68dda99 100644 --- a/runsc/test/integration/regression_test.go +++ b/runsc/test/integration/regression_test.go @@ -32,7 +32,7 @@ func TestBindOverlay(t *testing.T) { } d := testutil.MakeDocker("bind-overlay-test") - cmd := "nc -l -U /var/run/sock& sleep 1 && echo foobar-asdf | nc -U /var/run/sock" + cmd := "nc -l -U /var/run/sock & p=$! && sleep 1 && echo foobar-asdf | nc -U /var/run/sock && wait $p" got, err := d.RunFg("ubuntu:trusty", "bash", "-c", cmd) if err != nil { t.Fatal("docker run failed:", err) diff --git a/test/syscalls/BUILD b/test/syscalls/BUILD index 4ac511740..c54d90a1c 100644 --- a/test/syscalls/BUILD +++ b/test/syscalls/BUILD @@ -319,6 +319,10 @@ syscall_test( test = "//test/syscalls/linux:pwrite64_test", ) +syscall_test(test = "//test/syscalls/linux:raw_socket_hdrincl_test") + +syscall_test(test = "//test/syscalls/linux:raw_socket_icmp_test") + syscall_test(test = "//test/syscalls/linux:raw_socket_ipv4_test") syscall_test( diff --git a/test/syscalls/linux/BUILD b/test/syscalls/linux/BUILD index 4735284ba..c5a368463 100644 --- a/test/syscalls/linux/BUILD +++ b/test/syscalls/linux/BUILD @@ -1562,6 +1562,24 @@ cc_binary( ) cc_binary( + name = "raw_socket_hdrincl_test", + testonly = 1, + srcs = ["raw_socket_hdrincl.cc"], + linkstatic = 1, + deps = [ + ":socket_test_util", + ":unix_domain_socket_test_util", + "//test/util:capability_util", + "//test/util:file_descriptor", + "//test/util:test_main", + "//test/util:test_util", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/base:endian", + "@com_google_googletest//:gtest", + ], +) + +cc_binary( name = "raw_socket_ipv4_test", testonly = 1, srcs = ["raw_socket_ipv4.cc"], diff --git a/test/syscalls/linux/dev.cc b/test/syscalls/linux/dev.cc index 7f0a2ec45..4dd302eed 100644 --- a/test/syscalls/linux/dev.cc +++ b/test/syscalls/linux/dev.cc @@ -16,6 +16,7 @@ #include <sys/stat.h> #include <sys/types.h> #include <unistd.h> + #include <vector> #include "gmock/gmock.h" @@ -150,8 +151,6 @@ TEST(DevTest, TTYExists) { ASSERT_THAT(stat("/dev/tty", &statbuf), SyscallSucceeds()); // Check that it's a character device with rw-rw-rw- permissions. EXPECT_EQ(statbuf.st_mode, S_IFCHR | 0666); - // Check that it's owned by root. - EXPECT_EQ(statbuf.st_uid, 0); } } // namespace diff --git a/test/syscalls/linux/getdents.cc b/test/syscalls/linux/getdents.cc index 8e4efa8d6..fe9cfafe8 100644 --- a/test/syscalls/linux/getdents.cc +++ b/test/syscalls/linux/getdents.cc @@ -40,6 +40,7 @@ #include "test/util/temp_path.h" #include "test/util/test_util.h" +using ::testing::Contains; using ::testing::IsEmpty; using ::testing::IsSupersetOf; using ::testing::Not; @@ -484,6 +485,44 @@ TEST(ReaddirTest, Bug35110122Root) { EXPECT_THAT(contents, Not(IsEmpty())); } +// Unlink should invalidate getdents cache. +TEST(ReaddirTest, GoneAfterRemoveCache) { + TempPath dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); + TempPath file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileIn(dir.path())); + std::string name = std::string(Basename(file.path())); + + auto contents = ASSERT_NO_ERRNO_AND_VALUE(ListDir(dir.path(), true)); + EXPECT_THAT(contents, Contains(name)); + + file.reset(); + + contents = ASSERT_NO_ERRNO_AND_VALUE(ListDir(dir.path(), true)); + EXPECT_THAT(contents, Not(Contains(name))); +} + +// Regression test for b/137398511. Rename should invalidate getdents cache. +TEST(ReaddirTest, GoneAfterRenameCache) { + TempPath src = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); + TempPath dst = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); + + TempPath file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileIn(src.path())); + std::string name = std::string(Basename(file.path())); + + auto contents = ASSERT_NO_ERRNO_AND_VALUE(ListDir(src.path(), true)); + EXPECT_THAT(contents, Contains(name)); + + ASSERT_THAT(rename(file.path().c_str(), JoinPath(dst.path(), name).c_str()), + SyscallSucceeds()); + // Release file since it was renamed. dst cleanup will ultimately delete it. + file.release(); + + contents = ASSERT_NO_ERRNO_AND_VALUE(ListDir(src.path(), true)); + EXPECT_THAT(contents, Not(Contains(name))); + + contents = ASSERT_NO_ERRNO_AND_VALUE(ListDir(dst.path(), true)); + EXPECT_THAT(contents, Contains(name)); +} + } // namespace } // namespace testing diff --git a/test/syscalls/linux/proc.cc b/test/syscalls/linux/proc.cc index 490bf4424..b440ba0df 100644 --- a/test/syscalls/linux/proc.cc +++ b/test/syscalls/linux/proc.cc @@ -1964,6 +1964,22 @@ TEST(ProcPid, RootDumpableOwner) { EXPECT_THAT(st.st_gid, AnyOf(Eq(0), Eq(65534))); } +TEST(Proc, GetdentsEnoent) { + FileDescriptor fd; + ASSERT_NO_ERRNO(WithSubprocess( + [&](int pid) -> PosixError { + // Running. + ASSIGN_OR_RETURN_ERRNO(fd, Open(absl::StrCat("/proc/", pid, "/task"), + O_RDONLY | O_DIRECTORY)); + + return NoError(); + }, + nullptr, nullptr)); + char buf[1024]; + ASSERT_THAT(syscall(SYS_getdents, fd.get(), buf, sizeof(buf)), + SyscallFailsWithErrno(ENOENT)); +} + } // namespace } // namespace testing } // namespace gvisor diff --git a/test/syscalls/linux/raw_socket_hdrincl.cc b/test/syscalls/linux/raw_socket_hdrincl.cc new file mode 100644 index 000000000..a070817eb --- /dev/null +++ b/test/syscalls/linux/raw_socket_hdrincl.cc @@ -0,0 +1,408 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <linux/capability.h> +#include <netinet/in.h> +#include <netinet/ip.h> +#include <netinet/ip_icmp.h> +#include <netinet/udp.h> +#include <poll.h> +#include <sys/socket.h> +#include <sys/types.h> +#include <unistd.h> + +#include <algorithm> +#include <cstring> + +#include "gtest/gtest.h" +#include "absl/base/internal/endian.h" +#include "test/syscalls/linux/socket_test_util.h" +#include "test/syscalls/linux/unix_domain_socket_test_util.h" +#include "test/util/capability_util.h" +#include "test/util/file_descriptor.h" +#include "test/util/test_util.h" + +namespace gvisor { +namespace testing { + +namespace { + +// Tests for IPPROTO_RAW raw sockets, which implies IP_HDRINCL. +class RawHDRINCL : public ::testing::Test { + protected: + // Creates a socket to be used in tests. + void SetUp() override; + + // Closes the socket created by SetUp(). + void TearDown() override; + + // Returns a valid looback IP header with no payload. + struct iphdr LoopbackHeader(); + + // Fills in buf with an IP header, UDP header, and payload. Returns false if + // buf_size isn't large enough to hold everything. + bool FillPacket(char* buf, size_t buf_size, int port, const char* payload, + uint16_t payload_size); + + // The socket used for both reading and writing. + int socket_; + + // The loopback address. + struct sockaddr_in addr_; +}; + +void RawHDRINCL::SetUp() { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); + + ASSERT_THAT(socket_ = socket(AF_INET, SOCK_RAW, IPPROTO_RAW), + SyscallSucceeds()); + + addr_ = {}; + + addr_.sin_port = IPPROTO_IP; + addr_.sin_addr.s_addr = htonl(INADDR_LOOPBACK); + addr_.sin_family = AF_INET; +} + +void RawHDRINCL::TearDown() { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); + + EXPECT_THAT(close(socket_), SyscallSucceeds()); +} + +struct iphdr RawHDRINCL::LoopbackHeader() { + struct iphdr hdr = {}; + hdr.ihl = 5; + hdr.version = 4; + hdr.tos = 0; + hdr.tot_len = absl::gbswap_16(sizeof(hdr)); + hdr.id = 0; + hdr.frag_off = 0; + hdr.ttl = 7; + hdr.protocol = 1; + hdr.daddr = htonl(INADDR_LOOPBACK); + // hdr.check is set by the network stack. + // hdr.tot_len is set by the network stack. + // hdr.saddr is set by the network stack. + return hdr; +} + +bool RawHDRINCL::FillPacket(char* buf, size_t buf_size, int port, + const char* payload, uint16_t payload_size) { + if (buf_size < sizeof(struct iphdr) + sizeof(struct udphdr) + payload_size) { + return false; + } + + struct iphdr ip = LoopbackHeader(); + ip.protocol = IPPROTO_UDP; + + struct udphdr udp = {}; + udp.source = absl::gbswap_16(port); + udp.dest = absl::gbswap_16(port); + udp.len = absl::gbswap_16(sizeof(udp) + payload_size); + udp.check = 0; + + memcpy(buf, reinterpret_cast<char*>(&ip), sizeof(ip)); + memcpy(buf + sizeof(ip), reinterpret_cast<char*>(&udp), sizeof(udp)); + memcpy(buf + sizeof(ip) + sizeof(udp), payload, payload_size); + + return true; +} + +// We should be able to create multiple IPPROTO_RAW sockets. RawHDRINCL::Setup +// creates the first one, so we only have to create one more here. +TEST_F(RawHDRINCL, MultipleCreation) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); + + int s2; + ASSERT_THAT(s2 = socket(AF_INET, SOCK_RAW, IPPROTO_RAW), SyscallSucceeds()); + + ASSERT_THAT(close(s2), SyscallSucceeds()); +} + +// Test that shutting down an unconnected socket fails. +TEST_F(RawHDRINCL, FailShutdownWithoutConnect) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); + + ASSERT_THAT(shutdown(socket_, SHUT_WR), SyscallFailsWithErrno(ENOTCONN)); + ASSERT_THAT(shutdown(socket_, SHUT_RD), SyscallFailsWithErrno(ENOTCONN)); +} + +// Test that listen() fails. +TEST_F(RawHDRINCL, FailListen) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); + + ASSERT_THAT(listen(socket_, 1), SyscallFailsWithErrno(ENOTSUP)); +} + +// Test that accept() fails. +TEST_F(RawHDRINCL, FailAccept) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); + + struct sockaddr saddr; + socklen_t addrlen; + ASSERT_THAT(accept(socket_, &saddr, &addrlen), + SyscallFailsWithErrno(ENOTSUP)); +} + +// Test that the socket is writable immediately. +TEST_F(RawHDRINCL, PollWritableImmediately) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); + + struct pollfd pfd = {}; + pfd.fd = socket_; + pfd.events = POLLOUT; + ASSERT_THAT(RetryEINTR(poll)(&pfd, 1, 0), SyscallSucceedsWithValue(1)); +} + +// Test that the socket isn't readable. +TEST_F(RawHDRINCL, NotReadable) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); + + // Try to receive data with MSG_DONTWAIT, which returns immediately if there's + // nothing to be read. + char buf[117]; + ASSERT_THAT(RetryEINTR(recv)(socket_, buf, sizeof(buf), MSG_DONTWAIT), + SyscallFailsWithErrno(EINVAL)); +} + +// Test that we can connect() to a valid IP (loopback). +TEST_F(RawHDRINCL, ConnectToLoopback) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); + + ASSERT_THAT(connect(socket_, reinterpret_cast<struct sockaddr*>(&addr_), + sizeof(addr_)), + SyscallSucceeds()); +} + +TEST_F(RawHDRINCL, SendWithoutConnectSucceeds) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); + + struct iphdr hdr = LoopbackHeader(); + ASSERT_THAT(send(socket_, &hdr, sizeof(hdr), 0), + SyscallSucceedsWithValue(sizeof(hdr))); +} + +// HDRINCL implies write-only. Verify that we can't read a packet sent to +// loopback. +TEST_F(RawHDRINCL, NotReadableAfterWrite) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); + + ASSERT_THAT(connect(socket_, reinterpret_cast<struct sockaddr*>(&addr_), + sizeof(addr_)), + SyscallSucceeds()); + + // Construct a packet with an IP header, UDP header, and payload. + constexpr char kPayload[] = "odst"; + char packet[sizeof(struct iphdr) + sizeof(struct udphdr) + sizeof(kPayload)]; + ASSERT_TRUE(FillPacket(packet, sizeof(packet), 40000 /* port */, kPayload, + sizeof(kPayload))); + + socklen_t addrlen = sizeof(addr_); + ASSERT_NO_FATAL_FAILURE( + sendto(socket_, reinterpret_cast<void*>(&packet), sizeof(packet), 0, + reinterpret_cast<struct sockaddr*>(&addr_), addrlen)); + + struct pollfd pfd = {}; + pfd.fd = socket_; + pfd.events = POLLIN; + ASSERT_THAT(RetryEINTR(poll)(&pfd, 1, 1000), SyscallSucceedsWithValue(0)); +} + +TEST_F(RawHDRINCL, WriteTooSmall) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); + + ASSERT_THAT(connect(socket_, reinterpret_cast<struct sockaddr*>(&addr_), + sizeof(addr_)), + SyscallSucceeds()); + + // This is smaller than the size of an IP header. + constexpr char kBuf[] = "JP5"; + ASSERT_THAT(send(socket_, kBuf, sizeof(kBuf), 0), + SyscallFailsWithErrno(EINVAL)); +} + +// Bind to localhost. +TEST_F(RawHDRINCL, BindToLocalhost) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); + + ASSERT_THAT( + bind(socket_, reinterpret_cast<struct sockaddr*>(&addr_), sizeof(addr_)), + SyscallSucceeds()); +} + +// Bind to a different address. +TEST_F(RawHDRINCL, BindToInvalid) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); + + struct sockaddr_in bind_addr = {}; + bind_addr.sin_family = AF_INET; + bind_addr.sin_addr = {1}; // 1.0.0.0 - An address that we can't bind to. + ASSERT_THAT(bind(socket_, reinterpret_cast<struct sockaddr*>(&bind_addr), + sizeof(bind_addr)), + SyscallFailsWithErrno(EADDRNOTAVAIL)); +} + +// Send and receive a packet. +TEST_F(RawHDRINCL, SendAndReceive) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); + + int port = 40000; + if (!IsRunningOnGvisor()) { + port = static_cast<short>(ASSERT_NO_ERRNO_AND_VALUE( + PortAvailable(0, AddressFamily::kIpv4, SocketType::kUdp, false))); + } + + // IPPROTO_RAW sockets are write-only. We'll have to open another socket to + // read what we write. + FileDescriptor udp_sock = + ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET, SOCK_RAW, IPPROTO_UDP)); + + // Construct a packet with an IP header, UDP header, and payload. + constexpr char kPayload[] = "toto"; + char packet[sizeof(struct iphdr) + sizeof(struct udphdr) + sizeof(kPayload)]; + ASSERT_TRUE( + FillPacket(packet, sizeof(packet), port, kPayload, sizeof(kPayload))); + + socklen_t addrlen = sizeof(addr_); + ASSERT_NO_FATAL_FAILURE(sendto(socket_, &packet, sizeof(packet), 0, + reinterpret_cast<struct sockaddr*>(&addr_), + addrlen)); + + // Receive the payload. + char recv_buf[sizeof(packet)]; + struct sockaddr_in src; + socklen_t src_size = sizeof(src); + ASSERT_THAT(recvfrom(udp_sock.get(), recv_buf, sizeof(recv_buf), 0, + reinterpret_cast<struct sockaddr*>(&src), &src_size), + SyscallSucceedsWithValue(sizeof(packet))); + EXPECT_EQ( + memcmp(kPayload, recv_buf + sizeof(struct iphdr) + sizeof(struct udphdr), + sizeof(kPayload)), + 0); + // The network stack should have set the source address. + EXPECT_EQ(src.sin_family, AF_INET); + EXPECT_EQ(absl::gbswap_32(src.sin_addr.s_addr), INADDR_LOOPBACK); + // The packet ID should be 0, as the packet is less than 68 bytes. + struct iphdr iphdr = {}; + memcpy(&iphdr, recv_buf, sizeof(iphdr)); + EXPECT_EQ(iphdr.id, 0); +} + +// Send and receive a packet with nonzero IP ID. +TEST_F(RawHDRINCL, SendAndReceiveNonzeroID) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); + + int port = 40000; + if (!IsRunningOnGvisor()) { + port = static_cast<short>(ASSERT_NO_ERRNO_AND_VALUE( + PortAvailable(0, AddressFamily::kIpv4, SocketType::kUdp, false))); + } + + // IPPROTO_RAW sockets are write-only. We'll have to open another socket to + // read what we write. + FileDescriptor udp_sock = + ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET, SOCK_RAW, IPPROTO_UDP)); + + // Construct a packet with an IP header, UDP header, and payload. Make the + // payload large enough to force an IP ID to be assigned. + constexpr char kPayload[128] = {}; + char packet[sizeof(struct iphdr) + sizeof(struct udphdr) + sizeof(kPayload)]; + ASSERT_TRUE( + FillPacket(packet, sizeof(packet), port, kPayload, sizeof(kPayload))); + + socklen_t addrlen = sizeof(addr_); + ASSERT_NO_FATAL_FAILURE(sendto(socket_, &packet, sizeof(packet), 0, + reinterpret_cast<struct sockaddr*>(&addr_), + addrlen)); + + // Receive the payload. + char recv_buf[sizeof(packet)]; + struct sockaddr_in src; + socklen_t src_size = sizeof(src); + ASSERT_THAT(recvfrom(udp_sock.get(), recv_buf, sizeof(recv_buf), 0, + reinterpret_cast<struct sockaddr*>(&src), &src_size), + SyscallSucceedsWithValue(sizeof(packet))); + EXPECT_EQ( + memcmp(kPayload, recv_buf + sizeof(struct iphdr) + sizeof(struct udphdr), + sizeof(kPayload)), + 0); + // The network stack should have set the source address. + EXPECT_EQ(src.sin_family, AF_INET); + EXPECT_EQ(absl::gbswap_32(src.sin_addr.s_addr), INADDR_LOOPBACK); + // The packet ID should not be 0, as the packet was more than 68 bytes. + struct iphdr* iphdr = reinterpret_cast<struct iphdr*>(recv_buf); + EXPECT_NE(iphdr->id, 0); +} + +// Send and receive a packet where the sendto address is not the same as the +// provided destination. +TEST_F(RawHDRINCL, SendAndReceiveDifferentAddress) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); + + int port = 40000; + if (!IsRunningOnGvisor()) { + port = static_cast<short>(ASSERT_NO_ERRNO_AND_VALUE( + PortAvailable(0, AddressFamily::kIpv4, SocketType::kUdp, false))); + } + + // IPPROTO_RAW sockets are write-only. We'll have to open another socket to + // read what we write. + FileDescriptor udp_sock = + ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET, SOCK_RAW, IPPROTO_UDP)); + + // Construct a packet with an IP header, UDP header, and payload. + constexpr char kPayload[] = "toto"; + char packet[sizeof(struct iphdr) + sizeof(struct udphdr) + sizeof(kPayload)]; + ASSERT_TRUE( + FillPacket(packet, sizeof(packet), port, kPayload, sizeof(kPayload))); + // Overwrite the IP destination address with an IP we can't get to. + struct iphdr iphdr = {}; + memcpy(&iphdr, packet, sizeof(iphdr)); + iphdr.daddr = 42; + memcpy(packet, &iphdr, sizeof(iphdr)); + + socklen_t addrlen = sizeof(addr_); + ASSERT_NO_FATAL_FAILURE(sendto(socket_, &packet, sizeof(packet), 0, + reinterpret_cast<struct sockaddr*>(&addr_), + addrlen)); + + // Receive the payload, since sendto should replace the bad destination with + // localhost. + char recv_buf[sizeof(packet)]; + struct sockaddr_in src; + socklen_t src_size = sizeof(src); + ASSERT_THAT(recvfrom(udp_sock.get(), recv_buf, sizeof(recv_buf), 0, + reinterpret_cast<struct sockaddr*>(&src), &src_size), + SyscallSucceedsWithValue(sizeof(packet))); + EXPECT_EQ( + memcmp(kPayload, recv_buf + sizeof(struct iphdr) + sizeof(struct udphdr), + sizeof(kPayload)), + 0); + // The network stack should have set the source address. + EXPECT_EQ(src.sin_family, AF_INET); + EXPECT_EQ(absl::gbswap_32(src.sin_addr.s_addr), INADDR_LOOPBACK); + // The packet ID should be 0, as the packet is less than 68 bytes. + struct iphdr recv_iphdr = {}; + memcpy(&recv_iphdr, recv_buf, sizeof(recv_iphdr)); + EXPECT_EQ(recv_iphdr.id, 0); + // The destination address should be localhost, not the bad IP we set + // initially. + EXPECT_EQ(absl::gbswap_32(recv_iphdr.daddr), INADDR_LOOPBACK); +} + +} // namespace + +} // namespace testing +} // namespace gvisor diff --git a/test/syscalls/linux/raw_socket_icmp.cc b/test/syscalls/linux/raw_socket_icmp.cc index 7fe7c03a5..ad19120d5 100644 --- a/test/syscalls/linux/raw_socket_icmp.cc +++ b/test/syscalls/linux/raw_socket_icmp.cc @@ -195,7 +195,7 @@ TEST_F(RawSocketICMPTest, MultipleSocketReceive) { // Receive on socket 1. constexpr int kBufSize = kEmptyICMPSize; - std::vector<char[kBufSize]> recv_buf1(2); + char recv_buf1[2][kBufSize]; struct sockaddr_in src; for (int i = 0; i < 2; i++) { ASSERT_NO_FATAL_FAILURE(ReceiveICMP(recv_buf1[i], @@ -205,7 +205,7 @@ TEST_F(RawSocketICMPTest, MultipleSocketReceive) { } // Receive on socket 2. - std::vector<char[kBufSize]> recv_buf2(2); + char recv_buf2[2][kBufSize]; for (int i = 0; i < 2; i++) { ASSERT_NO_FATAL_FAILURE( ReceiveICMPFrom(recv_buf2[i], ABSL_ARRAYSIZE(recv_buf2[i]), @@ -221,14 +221,14 @@ TEST_F(RawSocketICMPTest, MultipleSocketReceive) { reinterpret_cast<struct icmphdr*>(buf + sizeof(struct iphdr)); return icmp->type == type; }; - const char* icmp1 = - *std::find_if(recv_buf1.begin(), recv_buf1.end(), match_type); - const char* icmp2 = - *std::find_if(recv_buf2.begin(), recv_buf2.end(), match_type); - ASSERT_NE(icmp1, *recv_buf1.end()); - ASSERT_NE(icmp2, *recv_buf2.end()); - EXPECT_EQ(memcmp(icmp1 + sizeof(struct iphdr), icmp2 + sizeof(struct iphdr), - sizeof(icmp)), + auto icmp1_it = + std::find_if(std::begin(recv_buf1), std::end(recv_buf1), match_type); + auto icmp2_it = + std::find_if(std::begin(recv_buf2), std::end(recv_buf2), match_type); + ASSERT_NE(icmp1_it, std::end(recv_buf1)); + ASSERT_NE(icmp2_it, std::end(recv_buf2)); + EXPECT_EQ(memcmp(*icmp1_it + sizeof(struct iphdr), + *icmp2_it + sizeof(struct iphdr), sizeof(icmp)), 0); } } @@ -254,7 +254,7 @@ TEST_F(RawSocketICMPTest, RawAndPingSockets) { // Receive on socket 1, which receives the echo request and reply in // indeterminate order. constexpr int kBufSize = kEmptyICMPSize; - std::vector<char[kBufSize]> recv_buf1(2); + char recv_buf1[2][kBufSize]; struct sockaddr_in src; for (int i = 0; i < 2; i++) { ASSERT_NO_FATAL_FAILURE( @@ -274,11 +274,94 @@ TEST_F(RawSocketICMPTest, RawAndPingSockets) { reinterpret_cast<struct icmphdr*>(buf + sizeof(struct iphdr)); return icmp->type == ICMP_ECHOREPLY; }; - char* raw_reply = - *std::find_if(recv_buf1.begin(), recv_buf1.end(), match_type_raw); - ASSERT_NE(raw_reply, *recv_buf1.end()); + auto raw_reply_it = + std::find_if(std::begin(recv_buf1), std::end(recv_buf1), match_type_raw); + ASSERT_NE(raw_reply_it, std::end(recv_buf1)); EXPECT_EQ( - memcmp(raw_reply + sizeof(struct iphdr), ping_recv_buf, sizeof(icmp)), 0); + memcmp(*raw_reply_it + sizeof(struct iphdr), ping_recv_buf, sizeof(icmp)), + 0); +} + +// A raw ICMP socket should be able to send a malformed short ICMP Echo Request, +// while ping socket should not. +// Neither should be able to receieve a short malformed packet. +TEST_F(RawSocketICMPTest, ShortEchoRawAndPingSockets) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); + + FileDescriptor ping_sock = + ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET, SOCK_DGRAM, IPPROTO_ICMP)); + + struct icmphdr icmp; + icmp.type = ICMP_ECHO; + icmp.code = 0; + icmp.un.echo.sequence = 0; + icmp.un.echo.id = 6789; + icmp.checksum = 0; + icmp.checksum = Checksum(&icmp); + + // Omit 2 bytes from ICMP packet. + constexpr int kShortICMPSize = sizeof(icmp) - 2; + + // Sending a malformed short ICMP message to a ping socket should fail. + ASSERT_THAT(RetryEINTR(sendto)(ping_sock.get(), &icmp, kShortICMPSize, 0, + reinterpret_cast<struct sockaddr*>(&addr_), + sizeof(addr_)), + SyscallFailsWithErrno(EINVAL)); + + // Sending a malformed short ICMP message to a raw socket should not fail. + ASSERT_THAT(RetryEINTR(sendto)(s_, &icmp, kShortICMPSize, 0, + reinterpret_cast<struct sockaddr*>(&addr_), + sizeof(addr_)), + SyscallSucceedsWithValue(kShortICMPSize)); + + // Neither Ping nor Raw socket should have anything to read. + char recv_buf[kEmptyICMPSize]; + EXPECT_THAT(RetryEINTR(recv)(ping_sock.get(), recv_buf, sizeof(recv_buf), + MSG_DONTWAIT), + SyscallFailsWithErrno(EAGAIN)); + EXPECT_THAT(RetryEINTR(recv)(s_, recv_buf, sizeof(recv_buf), MSG_DONTWAIT), + SyscallFailsWithErrno(EAGAIN)); +} + +// A raw ICMP socket should be able to send a malformed short ICMP Echo Reply, +// while ping socket should not. +// Neither should be able to receieve a short malformed packet. +TEST_F(RawSocketICMPTest, ShortEchoReplyRawAndPingSockets) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); + + FileDescriptor ping_sock = + ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET, SOCK_DGRAM, IPPROTO_ICMP)); + + struct icmphdr icmp; + icmp.type = ICMP_ECHOREPLY; + icmp.code = 0; + icmp.un.echo.sequence = 0; + icmp.un.echo.id = 6789; + icmp.checksum = 0; + icmp.checksum = Checksum(&icmp); + + // Omit 2 bytes from ICMP packet. + constexpr int kShortICMPSize = sizeof(icmp) - 2; + + // Sending a malformed short ICMP message to a ping socket should fail. + ASSERT_THAT(RetryEINTR(sendto)(ping_sock.get(), &icmp, kShortICMPSize, 0, + reinterpret_cast<struct sockaddr*>(&addr_), + sizeof(addr_)), + SyscallFailsWithErrno(EINVAL)); + + // Sending a malformed short ICMP message to a raw socket should not fail. + ASSERT_THAT(RetryEINTR(sendto)(s_, &icmp, kShortICMPSize, 0, + reinterpret_cast<struct sockaddr*>(&addr_), + sizeof(addr_)), + SyscallSucceedsWithValue(kShortICMPSize)); + + // Neither Ping nor Raw socket should have anything to read. + char recv_buf[kEmptyICMPSize]; + EXPECT_THAT(RetryEINTR(recv)(ping_sock.get(), recv_buf, sizeof(recv_buf), + MSG_DONTWAIT), + SyscallFailsWithErrno(EAGAIN)); + EXPECT_THAT(RetryEINTR(recv)(s_, recv_buf, sizeof(recv_buf), MSG_DONTWAIT), + SyscallFailsWithErrno(EAGAIN)); } // Test that connect() sends packets to the right place. diff --git a/test/syscalls/linux/sigaltstack.cc b/test/syscalls/linux/sigaltstack.cc index 7d4a12c1d..69b6e4f90 100644 --- a/test/syscalls/linux/sigaltstack.cc +++ b/test/syscalls/linux/sigaltstack.cc @@ -175,7 +175,7 @@ TEST(SigaltstackTest, WalksOffBottom) { // Trigger a single fault. badhandler_low_water_mark = - reinterpret_cast<char*>(&stack.ss_sp) + SIGSTKSZ; // Expected top. + static_cast<char*>(stack.ss_sp) + SIGSTKSZ; // Expected top. badhandler_recursive_faults = 0; // Disable refault. Fault(); EXPECT_TRUE(badhandler_on_sigaltstack); diff --git a/test/syscalls/linux/tcp_socket.cc b/test/syscalls/linux/tcp_socket.cc index 4597e91e3..8d77431f2 100644 --- a/test/syscalls/linux/tcp_socket.cc +++ b/test/syscalls/linux/tcp_socket.cc @@ -890,6 +890,61 @@ TEST_P(SimpleTcpSocketTest, SetCongestionControlFailsForUnsupported) { EXPECT_EQ(0, memcmp(got_cc, old_cc, sizeof(kTcpCaNameMax))); } +TEST_P(SimpleTcpSocketTest, MaxSegDefault) { + FileDescriptor s = + ASSERT_NO_ERRNO_AND_VALUE(Socket(GetParam(), SOCK_STREAM, IPPROTO_TCP)); + + constexpr int kDefaultMSS = 536; + int tcp_max_seg; + socklen_t optlen = sizeof(tcp_max_seg); + ASSERT_THAT( + getsockopt(s.get(), IPPROTO_TCP, TCP_MAXSEG, &tcp_max_seg, &optlen), + SyscallSucceedsWithValue(0)); + + EXPECT_EQ(kDefaultMSS, tcp_max_seg); + EXPECT_EQ(sizeof(tcp_max_seg), optlen); +} + +TEST_P(SimpleTcpSocketTest, SetMaxSeg) { + FileDescriptor s = + ASSERT_NO_ERRNO_AND_VALUE(Socket(GetParam(), SOCK_STREAM, IPPROTO_TCP)); + + constexpr int kDefaultMSS = 536; + constexpr int kTCPMaxSeg = 1024; + ASSERT_THAT(setsockopt(s.get(), IPPROTO_TCP, TCP_MAXSEG, &kTCPMaxSeg, + sizeof(kTCPMaxSeg)), + SyscallSucceedsWithValue(0)); + + // Linux actually never returns the user_mss value. It will always return the + // default MSS value defined above for an unconnected socket and always return + // the actual current MSS for a connected one. + int optval; + socklen_t optlen = sizeof(optval); + ASSERT_THAT(getsockopt(s.get(), IPPROTO_TCP, TCP_MAXSEG, &optval, &optlen), + SyscallSucceedsWithValue(0)); + + EXPECT_EQ(kDefaultMSS, optval); + EXPECT_EQ(sizeof(optval), optlen); +} + +TEST_P(SimpleTcpSocketTest, SetMaxSegFailsForInvalidMSSValues) { + FileDescriptor s = + ASSERT_NO_ERRNO_AND_VALUE(Socket(GetParam(), SOCK_STREAM, IPPROTO_TCP)); + + { + constexpr int tcp_max_seg = 10; + ASSERT_THAT(setsockopt(s.get(), IPPROTO_TCP, TCP_MAXSEG, &tcp_max_seg, + sizeof(tcp_max_seg)), + SyscallFailsWithErrno(EINVAL)); + } + { + constexpr int tcp_max_seg = 75000; + ASSERT_THAT(setsockopt(s.get(), IPPROTO_TCP, TCP_MAXSEG, &tcp_max_seg, + sizeof(tcp_max_seg)), + SyscallFailsWithErrno(EINVAL)); + } +} + INSTANTIATE_TEST_SUITE_P(AllInetTests, SimpleTcpSocketTest, ::testing::Values(AF_INET, AF_INET6)); |