diff options
Diffstat (limited to 'pkg')
34 files changed, 412 insertions, 120 deletions
diff --git a/pkg/sentry/fs/dev/tty.go b/pkg/sentry/fs/dev/tty.go index b4c2a62fd..87d80e292 100644 --- a/pkg/sentry/fs/dev/tty.go +++ b/pkg/sentry/fs/dev/tty.go @@ -32,6 +32,7 @@ type ttyInodeOperations struct { fsutil.InodeNoopWriteOut `state:"nosave"` fsutil.InodeNotDirectory `state:"nosave"` fsutil.InodeNotMappable `state:"nosave"` + fsutil.InodeNotOpenable `state:"nosave"` fsutil.InodeNotSocket `state:"nosave"` fsutil.InodeNotSymlink `state:"nosave"` fsutil.InodeVirtual `state:"nosave"` @@ -47,11 +48,6 @@ func newTTYDevice(ctx context.Context, owner fs.FileOwner, mode linux.FileMode) } } -// GetFile implements fs.InodeOperations.GetFile. -func (*ttyInodeOperations) GetFile(ctx context.Context, dirent *fs.Dirent, flags fs.FileFlags) (*fs.File, error) { - return fs.NewFile(ctx, dirent, flags, &ttyFileOperations{}), nil -} - // +stateify savable type ttyFileOperations struct { fsutil.FileNoSeek `state:"nosave"` diff --git a/pkg/sentry/fs/inode_overlay.go b/pkg/sentry/fs/inode_overlay.go index 24b769cfc..e0602da17 100644 --- a/pkg/sentry/fs/inode_overlay.go +++ b/pkg/sentry/fs/inode_overlay.go @@ -339,7 +339,9 @@ func overlayRemove(ctx context.Context, o *overlayEntry, parent *Dirent, child * } } if child.Inode.overlay.lowerExists { - return overlayCreateWhiteout(o.upper, child.name) + if err := overlayCreateWhiteout(o.upper, child.name); err != nil { + return err + } } // We've removed from the directory so we must drop the cache. o.markDirectoryDirty() @@ -418,10 +420,12 @@ func overlayRename(ctx context.Context, o *overlayEntry, oldParent *Dirent, rena return err } if renamed.Inode.overlay.lowerExists { - return overlayCreateWhiteout(oldParent.Inode.overlay.upper, oldName) + if err := overlayCreateWhiteout(oldParent.Inode.overlay.upper, oldName); err != nil { + return err + } } // We've changed the directory so we must drop the cache. - o.markDirectoryDirty() + oldParent.Inode.overlay.markDirectoryDirty() return nil } diff --git a/pkg/sentry/fs/proc/net.go b/pkg/sentry/fs/proc/net.go index 37694620c..6b839685b 100644 --- a/pkg/sentry/fs/proc/net.go +++ b/pkg/sentry/fs/proc/net.go @@ -155,37 +155,40 @@ func (n *netDev) ReadSeqFileData(ctx context.Context, h seqfile.SeqHandle) ([]se contents[1] = " face |bytes packets errs drop fifo frame compressed multicast|bytes packets errs drop fifo colls carrier compressed\n" for _, i := range interfaces { - // TODO(b/71872867): Collect stats from each inet.Stack - // implementation (hostinet, epsocket, and rpcinet). - // Implements the same format as // net/core/net-procfs.c:dev_seq_printf_stats. - l := fmt.Sprintf("%6s: %7d %7d %4d %4d %4d %5d %10d %9d %8d %7d %4d %4d %4d %5d %7d %10d\n", + var stats inet.StatDev + if err := n.s.Statistics(&stats, i.Name); err != nil { + log.Warningf("Failed to retrieve interface statistics for %v: %v", i.Name, err) + continue + } + l := fmt.Sprintf( + "%6s: %7d %7d %4d %4d %4d %5d %10d %9d %8d %7d %4d %4d %4d %5d %7d %10d\n", i.Name, // Received - 0, // bytes - 0, // packets - 0, // errors - 0, // dropped - 0, // fifo - 0, // frame - 0, // compressed - 0, // multicast + stats[0], // bytes + stats[1], // packets + stats[2], // errors + stats[3], // dropped + stats[4], // fifo + stats[5], // frame + stats[6], // compressed + stats[7], // multicast // Transmitted - 0, // bytes - 0, // packets - 0, // errors - 0, // dropped - 0, // fifo - 0, // frame - 0, // compressed - 0) // multicast + stats[8], // bytes + stats[9], // packets + stats[10], // errors + stats[11], // dropped + stats[12], // fifo + stats[13], // frame + stats[14], // compressed + stats[15]) // multicast contents = append(contents, l) } var data []seqfile.SeqData for _, l := range contents { - data = append(data, seqfile.SeqData{Buf: []byte(l), Handle: (*ifinet6)(nil)}) + data = append(data, seqfile.SeqData{Buf: []byte(l), Handle: (*netDev)(nil)}) } return data, 0 diff --git a/pkg/sentry/fs/proc/task.go b/pkg/sentry/fs/proc/task.go index ef0ca3301..87184ec67 100644 --- a/pkg/sentry/fs/proc/task.go +++ b/pkg/sentry/fs/proc/task.go @@ -162,6 +162,11 @@ func (f *subtasksFile) Readdir(ctx context.Context, file *fs.File, ser fs.Dentry // subtask to emit. offset := file.Offset() + tasks := f.t.ThreadGroup().MemberIDs(f.pidns) + if len(tasks) == 0 { + return offset, syserror.ENOENT + } + if offset == 0 { // Serialize "." and "..". root := fs.RootFromContext(ctx) @@ -178,12 +183,12 @@ func (f *subtasksFile) Readdir(ctx context.Context, file *fs.File, ser fs.Dentry } // Serialize tasks. - tasks := f.t.ThreadGroup().MemberIDs(f.pidns) taskInts := make([]int, 0, len(tasks)) for _, tid := range tasks { taskInts = append(taskInts, int(tid)) } + sort.Sort(sort.IntSlice(taskInts)) // Find the task to start at. idx := sort.SearchInts(taskInts, int(offset)) if idx == len(taskInts) { diff --git a/pkg/sentry/inet/inet.go b/pkg/sentry/inet/inet.go index 7c104fd47..5b75a4a06 100644 --- a/pkg/sentry/inet/inet.go +++ b/pkg/sentry/inet/inet.go @@ -49,6 +49,9 @@ type Stack interface { // SetTCPSACKEnabled attempts to change TCP selective acknowledgement // settings. SetTCPSACKEnabled(enabled bool) error + + // Statistics reports stack statistics. + Statistics(stat interface{}, arg string) error } // Interface contains information about a network interface. @@ -102,3 +105,7 @@ type TCPBufferSize struct { // Max is the maximum size. Max int } + +// StatDev describes one line of /proc/net/dev, i.e., stats for one network +// interface. +type StatDev [16]uint64 diff --git a/pkg/sentry/inet/test_stack.go b/pkg/sentry/inet/test_stack.go index 624371eb6..75f9e7a77 100644 --- a/pkg/sentry/inet/test_stack.go +++ b/pkg/sentry/inet/test_stack.go @@ -81,3 +81,8 @@ func (s *TestStack) SetTCPSACKEnabled(enabled bool) error { s.TCPSACKFlag = enabled return nil } + +// Statistics implements inet.Stack.Statistics. +func (s *TestStack) Statistics(stat interface{}, arg string) error { + return nil +} diff --git a/pkg/sentry/platform/kvm/machine_unsafe.go b/pkg/sentry/platform/kvm/machine_unsafe.go index 8d76e106e..405e00292 100644 --- a/pkg/sentry/platform/kvm/machine_unsafe.go +++ b/pkg/sentry/platform/kvm/machine_unsafe.go @@ -21,6 +21,7 @@ package kvm import ( "fmt" + "math" "sync/atomic" "syscall" "unsafe" @@ -134,7 +135,7 @@ func (c *vCPU) notify() { syscall.SYS_FUTEX, uintptr(unsafe.Pointer(&c.state)), linux.FUTEX_WAKE|linux.FUTEX_PRIVATE_FLAG, - ^uintptr(0), // Number of waiters. + math.MaxInt32, // Number of waiters. 0, 0, 0) if errno != 0 { throw("futex wake error") diff --git a/pkg/sentry/socket/epsocket/epsocket.go b/pkg/sentry/socket/epsocket/epsocket.go index 9d1bcfd41..69eff7373 100644 --- a/pkg/sentry/socket/epsocket/epsocket.go +++ b/pkg/sentry/socket/epsocket/epsocket.go @@ -867,6 +867,18 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfa return int32(v), nil + case linux.TCP_MAXSEG: + if outLen < sizeOfInt32 { + return nil, syserr.ErrInvalidArgument + } + + var v tcpip.MaxSegOption + if err := ep.GetSockOpt(&v); err != nil { + return nil, syserr.TranslateNetstackError(err) + } + + return int32(v), nil + case linux.TCP_KEEPIDLE: if outLen < sizeOfInt32 { return nil, syserr.ErrInvalidArgument @@ -1219,6 +1231,14 @@ func setSockOptTCP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) * v := usermem.ByteOrder.Uint32(optVal) return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.QuickAckOption(v))) + case linux.TCP_MAXSEG: + if len(optVal) < sizeOfInt32 { + return syserr.ErrInvalidArgument + } + + v := usermem.ByteOrder.Uint32(optVal) + return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.MaxSegOption(v))) + case linux.TCP_KEEPIDLE: if len(optVal) < sizeOfInt32 { return syserr.ErrInvalidArgument diff --git a/pkg/sentry/socket/epsocket/provider.go b/pkg/sentry/socket/epsocket/provider.go index 6d2b5d038..421f93dc4 100644 --- a/pkg/sentry/socket/epsocket/provider.go +++ b/pkg/sentry/socket/epsocket/provider.go @@ -40,42 +40,49 @@ type provider struct { } // getTransportProtocol figures out transport protocol. Currently only TCP, -// UDP, and ICMP are supported. -func getTransportProtocol(ctx context.Context, stype linux.SockType, protocol int) (tcpip.TransportProtocolNumber, *syserr.Error) { +// UDP, and ICMP are supported. The bool return value is true when this socket +// is associated with a transport protocol. This is only false for SOCK_RAW, +// IPPROTO_IP sockets. +func getTransportProtocol(ctx context.Context, stype linux.SockType, protocol int) (tcpip.TransportProtocolNumber, bool, *syserr.Error) { switch stype { case linux.SOCK_STREAM: if protocol != 0 && protocol != syscall.IPPROTO_TCP { - return 0, syserr.ErrInvalidArgument + return 0, true, syserr.ErrInvalidArgument } - return tcp.ProtocolNumber, nil + return tcp.ProtocolNumber, true, nil case linux.SOCK_DGRAM: switch protocol { case 0, syscall.IPPROTO_UDP: - return udp.ProtocolNumber, nil + return udp.ProtocolNumber, true, nil case syscall.IPPROTO_ICMP: - return header.ICMPv4ProtocolNumber, nil + return header.ICMPv4ProtocolNumber, true, nil case syscall.IPPROTO_ICMPV6: - return header.ICMPv6ProtocolNumber, nil + return header.ICMPv6ProtocolNumber, true, nil } case linux.SOCK_RAW: // Raw sockets require CAP_NET_RAW. creds := auth.CredentialsFromContext(ctx) if !creds.HasCapability(linux.CAP_NET_RAW) { - return 0, syserr.ErrPermissionDenied + return 0, true, syserr.ErrPermissionDenied } switch protocol { case syscall.IPPROTO_ICMP: - return header.ICMPv4ProtocolNumber, nil + return header.ICMPv4ProtocolNumber, true, nil case syscall.IPPROTO_UDP: - return header.UDPProtocolNumber, nil + return header.UDPProtocolNumber, true, nil case syscall.IPPROTO_TCP: - return header.TCPProtocolNumber, nil + return header.TCPProtocolNumber, true, nil + // IPPROTO_RAW signifies that the raw socket isn't assigned to + // a transport protocol. Users will be able to write packets' + // IP headers and won't receive anything. + case syscall.IPPROTO_RAW: + return tcpip.TransportProtocolNumber(0), false, nil } } - return 0, syserr.ErrProtocolNotSupported + return 0, true, syserr.ErrProtocolNotSupported } // Socket creates a new socket object for the AF_INET or AF_INET6 family. @@ -93,7 +100,7 @@ func (p *provider) Socket(t *kernel.Task, stype linux.SockType, protocol int) (* } // Figure out the transport protocol. - transProto, err := getTransportProtocol(t, stype, protocol) + transProto, associated, err := getTransportProtocol(t, stype, protocol) if err != nil { return nil, err } @@ -103,7 +110,7 @@ func (p *provider) Socket(t *kernel.Task, stype linux.SockType, protocol int) (* var e *tcpip.Error wq := &waiter.Queue{} if stype == linux.SOCK_RAW { - ep, e = eps.Stack.NewRawEndpoint(transProto, p.netProto, wq) + ep, e = eps.Stack.NewRawEndpoint(transProto, p.netProto, wq, associated) } else { ep, e = eps.Stack.NewEndpoint(transProto, p.netProto, wq) } diff --git a/pkg/sentry/socket/epsocket/stack.go b/pkg/sentry/socket/epsocket/stack.go index 1627a4f68..7eef19f74 100644 --- a/pkg/sentry/socket/epsocket/stack.go +++ b/pkg/sentry/socket/epsocket/stack.go @@ -138,3 +138,8 @@ func (s *Stack) TCPSACKEnabled() (bool, error) { func (s *Stack) SetTCPSACKEnabled(enabled bool) error { return syserr.TranslateNetstackError(s.Stack.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.SACKEnabled(enabled))).ToError() } + +// Statistics implements inet.Stack.Statistics. +func (s *Stack) Statistics(stat interface{}, arg string) error { + return syserr.ErrEndpointOperation.ToError() +} diff --git a/pkg/sentry/socket/hostinet/stack.go b/pkg/sentry/socket/hostinet/stack.go index 11f94281c..cc1f66fa1 100644 --- a/pkg/sentry/socket/hostinet/stack.go +++ b/pkg/sentry/socket/hostinet/stack.go @@ -244,3 +244,8 @@ func (s *Stack) TCPSACKEnabled() (bool, error) { func (s *Stack) SetTCPSACKEnabled(enabled bool) error { return syserror.EACCES } + +// Statistics implements inet.Stack.Statistics. +func (s *Stack) Statistics(stat interface{}, arg string) error { + return syserror.EOPNOTSUPP +} diff --git a/pkg/sentry/socket/rpcinet/stack.go b/pkg/sentry/socket/rpcinet/stack.go index 3038f25a7..49bd3a220 100644 --- a/pkg/sentry/socket/rpcinet/stack.go +++ b/pkg/sentry/socket/rpcinet/stack.go @@ -133,3 +133,8 @@ func (s *Stack) TCPSACKEnabled() (bool, error) { func (s *Stack) SetTCPSACKEnabled(enabled bool) error { panic("rpcinet handles procfs directly this method should not be called") } + +// Statistics implements inet.Stack.Statistics. +func (s *Stack) Statistics(stat interface{}, arg string) error { + return syserr.ErrEndpointOperation.ToError() +} diff --git a/pkg/sentry/syscalls/linux/linux64.go b/pkg/sentry/syscalls/linux/linux64.go index 680c0c64c..51db2d8f7 100644 --- a/pkg/sentry/syscalls/linux/linux64.go +++ b/pkg/sentry/syscalls/linux/linux64.go @@ -234,9 +234,9 @@ var AMD64 = &kernel.SyscallTable{ 185: syscalls.Error("security", syserror.ENOSYS, "Not implemented in Linux.", nil), 186: syscalls.Supported("gettid", Gettid), 187: syscalls.ErrorWithEvent("readahead", syserror.ENOSYS, "", []string{"gvisor.dev/issue/261"}), // TODO(b/29351341) - 188: syscalls.ErrorWithEvent("setxattr", syserror.ENOTSUP, "Requires filesystem support.", nil), - 189: syscalls.ErrorWithEvent("lsetxattr", syserror.ENOTSUP, "Requires filesystem support.", nil), - 190: syscalls.ErrorWithEvent("fsetxattr", syserror.ENOTSUP, "Requires filesystem support.", nil), + 188: syscalls.Error("setxattr", syserror.ENOTSUP, "Requires filesystem support.", nil), + 189: syscalls.Error("lsetxattr", syserror.ENOTSUP, "Requires filesystem support.", nil), + 190: syscalls.Error("fsetxattr", syserror.ENOTSUP, "Requires filesystem support.", nil), 191: syscalls.ErrorWithEvent("getxattr", syserror.ENOTSUP, "Requires filesystem support.", nil), 192: syscalls.ErrorWithEvent("lgetxattr", syserror.ENOTSUP, "Requires filesystem support.", nil), 193: syscalls.ErrorWithEvent("fgetxattr", syserror.ENOTSUP, "Requires filesystem support.", nil), diff --git a/pkg/tcpip/header/icmpv4.go b/pkg/tcpip/header/icmpv4.go index c081de61f..c52c0d851 100644 --- a/pkg/tcpip/header/icmpv4.go +++ b/pkg/tcpip/header/icmpv4.go @@ -24,15 +24,11 @@ import ( type ICMPv4 []byte const ( - // ICMPv4MinimumSize is the minimum size of a valid ICMP packet. - ICMPv4MinimumSize = 4 - - // ICMPv4EchoMinimumSize is the minimum size of a valid ICMP echo packet. - ICMPv4EchoMinimumSize = 6 + // ICMPv4PayloadOffset defines the start of ICMP payload. + ICMPv4PayloadOffset = 4 - // ICMPv4DstUnreachableMinimumSize is the minimum size of a valid ICMP - // destination unreachable packet. - ICMPv4DstUnreachableMinimumSize = ICMPv4MinimumSize + 4 + // ICMPv4MinimumSize is the minimum size of a valid ICMP packet. + ICMPv4MinimumSize = 8 // ICMPv4ProtocolNumber is the ICMP transport protocol number. ICMPv4ProtocolNumber tcpip.TransportProtocolNumber = 1 @@ -104,5 +100,5 @@ func (ICMPv4) SetDestinationPort(uint16) { // Payload implements Transport.Payload. func (b ICMPv4) Payload() []byte { - return b[ICMPv4MinimumSize:] + return b[ICMPv4PayloadOffset:] } diff --git a/pkg/tcpip/header/ipv4.go b/pkg/tcpip/header/ipv4.go index 7da4c4845..94a3af289 100644 --- a/pkg/tcpip/header/ipv4.go +++ b/pkg/tcpip/header/ipv4.go @@ -85,6 +85,10 @@ const ( // units, the header cannot exceed 15*4 = 60 bytes. IPv4MaximumHeaderSize = 60 + // MinIPFragmentPayloadSize is the minimum number of payload bytes that + // the first fragment must carry when an IPv4 packet is fragmented. + MinIPFragmentPayloadSize = 8 + // IPv4AddressSize is the size, in bytes, of an IPv4 address. IPv4AddressSize = 4 @@ -268,6 +272,10 @@ func (b IPv4) IsValid(pktSize int) bool { return false } + if IPVersion(b) != IPv4Version { + return false + } + return true } diff --git a/pkg/tcpip/header/ipv6.go b/pkg/tcpip/header/ipv6.go index 7163eaa36..95fe8bfc3 100644 --- a/pkg/tcpip/header/ipv6.go +++ b/pkg/tcpip/header/ipv6.go @@ -184,6 +184,10 @@ func (b IPv6) IsValid(pktSize int) bool { return false } + if IPVersion(b) != IPv6Version { + return false + } + return true } diff --git a/pkg/tcpip/header/tcp.go b/pkg/tcpip/header/tcp.go index 1141443bb..82cfe785c 100644 --- a/pkg/tcpip/header/tcp.go +++ b/pkg/tcpip/header/tcp.go @@ -176,6 +176,21 @@ const ( // TCPProtocolNumber is TCP's transport protocol number. TCPProtocolNumber tcpip.TransportProtocolNumber = 6 + + // TCPMinimumMSS is the minimum acceptable value for MSS. This is the + // same as the value TCP_MIN_MSS defined net/tcp.h. + TCPMinimumMSS = IPv4MaximumHeaderSize + TCPHeaderMaximumSize + MinIPFragmentPayloadSize - IPv4MinimumSize - TCPMinimumSize + + // TCPMaximumMSS is the maximum acceptable value for MSS. + TCPMaximumMSS = 0xffff + + // TCPDefaultMSS is the MSS value that should be used if an MSS option + // is not received from the peer. It's also the value returned by + // TCP_MAXSEG option for a socket in an unconnected state. + // + // Per RFC 1122, page 85: "If an MSS option is not received at + // connection setup, TCP MUST assume a default send MSS of 536." + TCPDefaultMSS = 536 ) // SourcePort returns the "source port" field of the tcp header. @@ -306,7 +321,7 @@ func ParseSynOptions(opts []byte, isAck bool) TCPSynOptions { synOpts := TCPSynOptions{ // Per RFC 1122, page 85: "If an MSS option is not received at // connection setup, TCP MUST assume a default send MSS of 536." - MSS: 536, + MSS: TCPDefaultMSS, // If no window scale option is specified, WS in options is // returned as -1; this is because the absence of the option // indicates that the we cannot use window scaling on the diff --git a/pkg/tcpip/network/arp/arp.go b/pkg/tcpip/network/arp/arp.go index ca3d6c0bf..cb35635fc 100644 --- a/pkg/tcpip/network/arp/arp.go +++ b/pkg/tcpip/network/arp/arp.go @@ -83,6 +83,10 @@ func (e *endpoint) WritePacket(*stack.Route, *stack.GSO, buffer.Prependable, buf return tcpip.ErrNotSupported } +func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, payload buffer.VectorisedView, loop stack.PacketLooping) *tcpip.Error { + return tcpip.ErrNotSupported +} + func (e *endpoint) HandlePacket(r *stack.Route, vv buffer.VectorisedView) { v := vv.First() h := header.ARP(v) diff --git a/pkg/tcpip/network/ip_test.go b/pkg/tcpip/network/ip_test.go index db65ee7cc..8ff428445 100644 --- a/pkg/tcpip/network/ip_test.go +++ b/pkg/tcpip/network/ip_test.go @@ -282,10 +282,10 @@ func TestIPv4ReceiveControl(t *testing.T) { {"Truncated (10 bytes missing)", 0, 0, header.ICMPv4FragmentationNeeded, stack.ControlPacketTooBig, mtu, 10}, {"Truncated (missing IPv4 header)", 0, 0, header.ICMPv4FragmentationNeeded, stack.ControlPacketTooBig, mtu, header.IPv4MinimumSize + 8}, {"Truncated (missing 'extra info')", 0, 0, header.ICMPv4FragmentationNeeded, stack.ControlPacketTooBig, mtu, 4 + header.IPv4MinimumSize + 8}, - {"Truncated (missing ICMP header)", 0, 0, header.ICMPv4FragmentationNeeded, stack.ControlPacketTooBig, mtu, header.ICMPv4DstUnreachableMinimumSize + header.IPv4MinimumSize + 8}, + {"Truncated (missing ICMP header)", 0, 0, header.ICMPv4FragmentationNeeded, stack.ControlPacketTooBig, mtu, header.ICMPv4MinimumSize + header.IPv4MinimumSize + 8}, {"Port unreachable", 1, 0, header.ICMPv4PortUnreachable, stack.ControlPortUnreachable, 0, 0}, {"Non-zero fragment offset", 0, 100, header.ICMPv4PortUnreachable, stack.ControlPortUnreachable, 0, 0}, - {"Zero-length packet", 0, 0, header.ICMPv4PortUnreachable, stack.ControlPortUnreachable, 0, 2*header.IPv4MinimumSize + header.ICMPv4DstUnreachableMinimumSize + 8}, + {"Zero-length packet", 0, 0, header.ICMPv4PortUnreachable, stack.ControlPortUnreachable, 0, 2*header.IPv4MinimumSize + header.ICMPv4MinimumSize + 8}, } r, err := buildIPv4Route(localIpv4Addr, "\x0a\x00\x00\xbb") if err != nil { @@ -301,7 +301,7 @@ func TestIPv4ReceiveControl(t *testing.T) { } defer ep.Close() - const dataOffset = header.IPv4MinimumSize*2 + header.ICMPv4MinimumSize + 4 + const dataOffset = header.IPv4MinimumSize*2 + header.ICMPv4MinimumSize view := buffer.NewView(dataOffset + 8) // Create the outer IPv4 header. @@ -319,10 +319,10 @@ func TestIPv4ReceiveControl(t *testing.T) { icmp := header.ICMPv4(view[header.IPv4MinimumSize:]) icmp.SetType(header.ICMPv4DstUnreachable) icmp.SetCode(c.code) - copy(view[header.IPv4MinimumSize+header.ICMPv4MinimumSize:], []byte{0xde, 0xad, 0xbe, 0xef}) + copy(view[header.IPv4MinimumSize+header.ICMPv4PayloadOffset:], []byte{0xde, 0xad, 0xbe, 0xef}) // Create the inner IPv4 header. - ip = header.IPv4(view[header.IPv4MinimumSize+header.ICMPv4MinimumSize+4:]) + ip = header.IPv4(view[header.IPv4MinimumSize+header.ICMPv4MinimumSize:]) ip.Encode(&header.IPv4Fields{ IHL: header.IPv4MinimumSize, TotalLength: 100, diff --git a/pkg/tcpip/network/ipv4/icmp.go b/pkg/tcpip/network/ipv4/icmp.go index bc7f1c42a..fbef6947d 100644 --- a/pkg/tcpip/network/ipv4/icmp.go +++ b/pkg/tcpip/network/ipv4/icmp.go @@ -68,10 +68,6 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, vv buffer.V switch h.Type() { case header.ICMPv4Echo: received.Echo.Increment() - if len(v) < header.ICMPv4EchoMinimumSize { - received.Invalid.Increment() - return - } // Only send a reply if the checksum is valid. wantChecksum := h.Checksum() @@ -93,9 +89,9 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, vv buffer.V e.dispatcher.DeliverTransportPacket(r, header.ICMPv4ProtocolNumber, netHeader, vv) vv := vv.Clone(nil) - vv.TrimFront(header.ICMPv4EchoMinimumSize) - hdr := buffer.NewPrependable(int(r.MaxHeaderLength()) + header.ICMPv4EchoMinimumSize) - pkt := header.ICMPv4(hdr.Prepend(header.ICMPv4EchoMinimumSize)) + vv.TrimFront(header.ICMPv4MinimumSize) + hdr := buffer.NewPrependable(int(r.MaxHeaderLength()) + header.ICMPv4MinimumSize) + pkt := header.ICMPv4(hdr.Prepend(header.ICMPv4MinimumSize)) copy(pkt, h) pkt.SetType(header.ICMPv4EchoReply) pkt.SetChecksum(^header.Checksum(pkt, header.ChecksumVV(vv, 0))) @@ -108,25 +104,19 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, vv buffer.V case header.ICMPv4EchoReply: received.EchoReply.Increment() - if len(v) < header.ICMPv4EchoMinimumSize { - received.Invalid.Increment() - return - } + e.dispatcher.DeliverTransportPacket(r, header.ICMPv4ProtocolNumber, netHeader, vv) case header.ICMPv4DstUnreachable: received.DstUnreachable.Increment() - if len(v) < header.ICMPv4DstUnreachableMinimumSize { - received.Invalid.Increment() - return - } - vv.TrimFront(header.ICMPv4DstUnreachableMinimumSize) + + vv.TrimFront(header.ICMPv4MinimumSize) switch h.Code() { case header.ICMPv4PortUnreachable: e.handleControl(stack.ControlPortUnreachable, 0, vv) case header.ICMPv4FragmentationNeeded: - mtu := uint32(binary.BigEndian.Uint16(v[header.ICMPv4DstUnreachableMinimumSize-2:])) + mtu := uint32(binary.BigEndian.Uint16(v[header.ICMPv4PayloadOffset+2:])) e.handleControl(stack.ControlPacketTooBig, calculateMTU(mtu), vv) } diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go index 1e3a7425a..e44a73d96 100644 --- a/pkg/tcpip/network/ipv4/ipv4.go +++ b/pkg/tcpip/network/ipv4/ipv4.go @@ -232,6 +232,55 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, hdr buffer.Prepen return nil } +// WriteHeaderIncludedPacket writes a packet already containing a network +// header through the given route. +func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, payload buffer.VectorisedView, loop stack.PacketLooping) *tcpip.Error { + // The packet already has an IP header, but there are a few required + // checks. + ip := header.IPv4(payload.First()) + if !ip.IsValid(payload.Size()) { + return tcpip.ErrInvalidOptionValue + } + + // Always set the total length. + ip.SetTotalLength(uint16(payload.Size())) + + // Set the source address when zero. + if ip.SourceAddress() == tcpip.Address(([]byte{0, 0, 0, 0})) { + ip.SetSourceAddress(r.LocalAddress) + } + + // Set the destination. If the packet already included a destination, + // it will be part of the route. + ip.SetDestinationAddress(r.RemoteAddress) + + // Set the packet ID when zero. + if ip.ID() == 0 { + id := uint32(0) + if payload.Size() > header.IPv4MaximumHeaderSize+8 { + // Packets of 68 bytes or less are required by RFC 791 to not be + // fragmented, so we only assign ids to larger packets. + id = atomic.AddUint32(&ids[hashRoute(r, 0 /* protocol */)%buckets], 1) + } + ip.SetID(uint16(id)) + } + + // Always set the checksum. + ip.SetChecksum(0) + ip.SetChecksum(^ip.CalculateChecksum()) + + if loop&stack.PacketLoop != 0 { + e.HandlePacket(r, payload) + } + if loop&stack.PacketOut == 0 { + return nil + } + + hdr := buffer.NewPrependableFromView(payload.ToView()) + r.Stats().IP.PacketsSent.Increment() + return e.linkEP.WritePacket(r, nil /* gso */, hdr, buffer.VectorisedView{}, ProtocolNumber) +} + // HandlePacket is called by the link layer when new ipv4 packets arrive for // this endpoint. func (e *endpoint) HandlePacket(r *stack.Route, vv buffer.VectorisedView) { diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go index 27367d6c5..e3e8739fd 100644 --- a/pkg/tcpip/network/ipv6/ipv6.go +++ b/pkg/tcpip/network/ipv6/ipv6.go @@ -120,6 +120,13 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, hdr buffer.Prepen return e.linkEP.WritePacket(r, gso, hdr, payload, ProtocolNumber) } +// WriteHeaderIncludedPacker implements stack.NetworkEndpoint. It is not yet +// supported by IPv6. +func (*endpoint) WriteHeaderIncludedPacket(r *stack.Route, payload buffer.VectorisedView, loop stack.PacketLooping) *tcpip.Error { + // TODO(b/119580726): Support IPv6 header-included packets. + return tcpip.ErrNotSupported +} + // HandlePacket is called by the link layer when new ipv6 packets arrive for // this endpoint. func (e *endpoint) HandlePacket(r *stack.Route, vv buffer.VectorisedView) { diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go index 0ecaa0833..462265281 100644 --- a/pkg/tcpip/stack/registration.go +++ b/pkg/tcpip/stack/registration.go @@ -174,6 +174,10 @@ type NetworkEndpoint interface { // protocol. WritePacket(r *Route, gso *GSO, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.TransportProtocolNumber, ttl uint8, loop PacketLooping) *tcpip.Error + // WriteHeaderIncludedPacket writes a packet that includes a network + // header to the given destination address. + WriteHeaderIncludedPacket(r *Route, payload buffer.VectorisedView, loop PacketLooping) *tcpip.Error + // ID returns the network protocol endpoint ID. ID() *NetworkEndpointID @@ -357,10 +361,19 @@ type TransportProtocolFactory func() TransportProtocol // instantiate network protocols. type NetworkProtocolFactory func() NetworkProtocol +// UnassociatedEndpointFactory produces endpoints for writing packets not +// associated with a particular transport protocol. Such endpoints can be used +// to write arbitrary packets that include the IP header. +type UnassociatedEndpointFactory interface { + NewUnassociatedRawEndpoint(stack *Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) +} + var ( transportProtocols = make(map[string]TransportProtocolFactory) networkProtocols = make(map[string]NetworkProtocolFactory) + unassociatedFactory UnassociatedEndpointFactory + linkEPMu sync.RWMutex nextLinkEndpointID tcpip.LinkEndpointID = 1 linkEndpoints = make(map[tcpip.LinkEndpointID]LinkEndpoint) @@ -380,6 +393,13 @@ func RegisterNetworkProtocolFactory(name string, p NetworkProtocolFactory) { networkProtocols[name] = p } +// RegisterUnassociatedFactory registers a factory to produce endpoints not +// associated with any particular transport protocol. This function is intended +// to be called by init() functions of the protocols. +func RegisterUnassociatedFactory(f UnassociatedEndpointFactory) { + unassociatedFactory = f +} + // RegisterLinkEndpoint register a link-layer protocol endpoint and returns an // ID that can be used to refer to it. func RegisterLinkEndpoint(linkEP LinkEndpoint) tcpip.LinkEndpointID { diff --git a/pkg/tcpip/stack/route.go b/pkg/tcpip/stack/route.go index 36d7b6ac7..391ab4344 100644 --- a/pkg/tcpip/stack/route.go +++ b/pkg/tcpip/stack/route.go @@ -163,6 +163,18 @@ func (r *Route) WritePacket(gso *GSO, hdr buffer.Prependable, payload buffer.Vec return err } +// WriteHeaderIncludedPacket writes a packet already containing a network +// header through the given route. +func (r *Route) WriteHeaderIncludedPacket(payload buffer.VectorisedView) *tcpip.Error { + if err := r.ref.ep.WriteHeaderIncludedPacket(r, payload, r.loop); err != nil { + r.Stats().IP.OutgoingPacketErrors.Increment() + return err + } + r.ref.nic.stats.Tx.Packets.Increment() + r.ref.nic.stats.Tx.Bytes.IncrementBy(uint64(payload.Size())) + return nil +} + // DefaultTTL returns the default TTL of the underlying network endpoint. func (r *Route) DefaultTTL() uint8 { return r.ref.ep.DefaultTTL() diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index 2d7f56ca9..3e8fb2a6c 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -340,6 +340,8 @@ type Stack struct { networkProtocols map[tcpip.NetworkProtocolNumber]NetworkProtocol linkAddrResolvers map[tcpip.NetworkProtocolNumber]LinkAddressResolver + unassociatedFactory UnassociatedEndpointFactory + demux *transportDemuxer stats tcpip.Stats @@ -442,6 +444,8 @@ func New(network []string, transport []string, opts Options) *Stack { } } + s.unassociatedFactory = unassociatedFactory + // Create the global transport demuxer. s.demux = newTransportDemuxer(s) @@ -574,11 +578,15 @@ func (s *Stack) NewEndpoint(transport tcpip.TransportProtocolNumber, network tcp // NewRawEndpoint creates a new raw transport layer endpoint of the given // protocol. Raw endpoints receive all traffic for a given protocol regardless // of address. -func (s *Stack) NewRawEndpoint(transport tcpip.TransportProtocolNumber, network tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { +func (s *Stack) NewRawEndpoint(transport tcpip.TransportProtocolNumber, network tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue, associated bool) (tcpip.Endpoint, *tcpip.Error) { if !s.raw { return nil, tcpip.ErrNotPermitted } + if !associated { + return s.unassociatedFactory.NewUnassociatedRawEndpoint(s, network, transport, waiterQueue) + } + t, ok := s.transportProtocols[transport] if !ok { return nil, tcpip.ErrUnknownProtocol diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go index 69884af03..959071dbe 100644 --- a/pkg/tcpip/stack/stack_test.go +++ b/pkg/tcpip/stack/stack_test.go @@ -137,6 +137,10 @@ func (f *fakeNetworkEndpoint) WritePacket(r *stack.Route, gso *stack.GSO, hdr bu return f.linkEP.WritePacket(r, gso, hdr, payload, fakeNetNumber) } +func (*fakeNetworkEndpoint) WriteHeaderIncludedPacket(r *stack.Route, payload buffer.VectorisedView, loop stack.PacketLooping) *tcpip.Error { + return tcpip.ErrNotSupported +} + func (*fakeNetworkEndpoint) Close() {} type fakeNetGoodOption bool diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go index c61f96fb0..c4076666a 100644 --- a/pkg/tcpip/tcpip.go +++ b/pkg/tcpip/tcpip.go @@ -496,6 +496,10 @@ type AvailableCongestionControlOption string // buffer moderation. type ModerateReceiveBufferOption bool +// MaxSegOption is used by SetSockOpt/GetSockOpt to set/get the current +// Maximum Segment Size(MSS) value as specified using the TCP_MAXSEG option. +type MaxSegOption int + // MulticastTTLOption is used by SetSockOpt/GetSockOpt to control the default // TTL value for multicast messages. The default is 1. type MulticastTTLOption uint8 diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go index ab9e80747..a80ceafd0 100644 --- a/pkg/tcpip/transport/icmp/endpoint.go +++ b/pkg/tcpip/transport/icmp/endpoint.go @@ -291,7 +291,7 @@ func (e *endpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (uintptr, <-c switch e.netProto { case header.IPv4ProtocolNumber: - err = e.send4(route, v) + err = send4(route, e.id.LocalPort, v) case header.IPv6ProtocolNumber: err = send6(route, e.id.LocalPort, v) @@ -352,20 +352,20 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { } } -func (e *endpoint) send4(r *stack.Route, data buffer.View) *tcpip.Error { - if len(data) < header.ICMPv4EchoMinimumSize { +func send4(r *stack.Route, ident uint16, data buffer.View) *tcpip.Error { + if len(data) < header.ICMPv4MinimumSize { return tcpip.ErrInvalidEndpointState } // Set the ident to the user-specified port. Sequence number should // already be set by the user. - binary.BigEndian.PutUint16(data[header.ICMPv4MinimumSize:], e.id.LocalPort) + binary.BigEndian.PutUint16(data[header.ICMPv4PayloadOffset:], ident) - hdr := buffer.NewPrependable(header.ICMPv4EchoMinimumSize + int(r.MaxHeaderLength())) + hdr := buffer.NewPrependable(header.ICMPv4MinimumSize + int(r.MaxHeaderLength())) - icmpv4 := header.ICMPv4(hdr.Prepend(header.ICMPv4EchoMinimumSize)) + icmpv4 := header.ICMPv4(hdr.Prepend(header.ICMPv4MinimumSize)) copy(icmpv4, data) - data = data[header.ICMPv4EchoMinimumSize:] + data = data[header.ICMPv4MinimumSize:] // Linux performs these basic checks. if icmpv4.Type() != header.ICMPv4Echo || icmpv4.Code() != 0 { diff --git a/pkg/tcpip/transport/icmp/protocol.go b/pkg/tcpip/transport/icmp/protocol.go index c89538131..7fdba5d56 100644 --- a/pkg/tcpip/transport/icmp/protocol.go +++ b/pkg/tcpip/transport/icmp/protocol.go @@ -90,19 +90,18 @@ func (p *protocol) NewRawEndpoint(stack *stack.Stack, netProto tcpip.NetworkProt func (p *protocol) MinimumPacketSize() int { switch p.number { case ProtocolNumber4: - return header.ICMPv4EchoMinimumSize + return header.ICMPv4MinimumSize case ProtocolNumber6: return header.ICMPv6EchoMinimumSize } panic(fmt.Sprint("unknown protocol number: ", p.number)) } -// ParsePorts returns the source and destination ports stored in the given icmp -// packet. +// ParsePorts in case of ICMP sets src to 0, dst to ICMP ID, and err to nil. func (p *protocol) ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error) { switch p.number { case ProtocolNumber4: - return 0, binary.BigEndian.Uint16(v[header.ICMPv4MinimumSize:]), nil + return 0, binary.BigEndian.Uint16(v[header.ICMPv4PayloadOffset:]), nil case ProtocolNumber6: return 0, binary.BigEndian.Uint16(v[header.ICMPv6MinimumSize:]), nil } diff --git a/pkg/tcpip/transport/raw/BUILD b/pkg/tcpip/transport/raw/BUILD index 34a14bf7f..bc4b255b4 100644 --- a/pkg/tcpip/transport/raw/BUILD +++ b/pkg/tcpip/transport/raw/BUILD @@ -21,6 +21,7 @@ go_library( "endpoint.go", "endpoint_state.go", "packet_list.go", + "protocol.go", ], importpath = "gvisor.dev/gvisor/pkg/tcpip/transport/raw", imports = ["gvisor.dev/gvisor/pkg/tcpip/buffer"], diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go index 42aded77f..a29587658 100644 --- a/pkg/tcpip/transport/raw/endpoint.go +++ b/pkg/tcpip/transport/raw/endpoint.go @@ -67,6 +67,7 @@ type endpoint struct { netProto tcpip.NetworkProtocolNumber transProto tcpip.TransportProtocolNumber waiterQueue *waiter.Queue + associated bool // The following fields are used to manage the receive queue and are // protected by rcvMu. @@ -97,8 +98,12 @@ type endpoint struct { } // NewEndpoint returns a raw endpoint for the given protocols. -// TODO(b/129292371): IP_HDRINCL, IPPROTO_RAW, and AF_PACKET. +// TODO(b/129292371): IP_HDRINCL and AF_PACKET. func NewEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { + return newEndpoint(stack, netProto, transProto, waiterQueue, true /* associated */) +} + +func newEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue, associated bool) (tcpip.Endpoint, *tcpip.Error) { if netProto != header.IPv4ProtocolNumber { return nil, tcpip.ErrUnknownProtocol } @@ -110,6 +115,16 @@ func NewEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, trans waiterQueue: waiterQueue, rcvBufSizeMax: 32 * 1024, sndBufSize: 32 * 1024, + associated: associated, + } + + // Unassociated endpoints are write-only and users call Write() with IP + // headers included. Because they're write-only, We don't need to + // register with the stack. + if !associated { + ep.rcvBufSizeMax = 0 + ep.waiterQueue = nil + return ep, nil } if err := ep.stack.RegisterRawTransportEndpoint(ep.registeredNIC, ep.netProto, ep.transProto, ep); err != nil { @@ -124,7 +139,7 @@ func (ep *endpoint) Close() { ep.mu.Lock() defer ep.mu.Unlock() - if ep.closed { + if ep.closed || !ep.associated { return } @@ -142,8 +157,11 @@ func (ep *endpoint) Close() { if ep.connected { ep.route.Release() + ep.connected = false } + ep.closed = true + ep.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.EventIn | waiter.EventOut) } @@ -152,6 +170,10 @@ func (ep *endpoint) ModerateRecvBuf(copied int) {} // Read implements tcpip.Endpoint.Read. func (ep *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) { + if !ep.associated { + return buffer.View{}, tcpip.ControlMessages{}, tcpip.ErrInvalidOptionValue + } + ep.rcvMu.Lock() // If there's no data to read, return that read would block or that the @@ -192,6 +214,33 @@ func (ep *endpoint) Write(payload tcpip.Payload, opts tcpip.WriteOptions) (uintp return 0, nil, tcpip.ErrInvalidEndpointState } + payloadBytes, err := payload.Get(payload.Size()) + if err != nil { + ep.mu.RUnlock() + return 0, nil, err + } + + // If this is an unassociated socket and callee provided a nonzero + // destination address, route using that address. + if !ep.associated { + ip := header.IPv4(payloadBytes) + if !ip.IsValid(payload.Size()) { + ep.mu.RUnlock() + return 0, nil, tcpip.ErrInvalidOptionValue + } + dstAddr := ip.DestinationAddress() + // Update dstAddr with the address in the IP header, unless + // opts.To is set (e.g. if sendto specifies a specific + // address). + if dstAddr != tcpip.Address([]byte{0, 0, 0, 0}) && opts.To == nil { + opts.To = &tcpip.FullAddress{ + NIC: 0, // NIC is unset. + Addr: dstAddr, // The address from the payload. + Port: 0, // There are no ports here. + } + } + } + // Did the user caller provide a destination? If not, use the connected // destination. if opts.To == nil { @@ -216,12 +265,12 @@ func (ep *endpoint) Write(payload tcpip.Payload, opts tcpip.WriteOptions) (uintp return 0, nil, tcpip.ErrInvalidEndpointState } - n, ch, err := ep.finishWrite(payload, savedRoute) + n, ch, err := ep.finishWrite(payloadBytes, savedRoute) ep.mu.Unlock() return n, ch, err } - n, ch, err := ep.finishWrite(payload, &ep.route) + n, ch, err := ep.finishWrite(payloadBytes, &ep.route) ep.mu.RUnlock() return n, ch, err } @@ -248,7 +297,7 @@ func (ep *endpoint) Write(payload tcpip.Payload, opts tcpip.WriteOptions) (uintp return 0, nil, err } - n, ch, err := ep.finishWrite(payload, &route) + n, ch, err := ep.finishWrite(payloadBytes, &route) route.Release() ep.mu.RUnlock() return n, ch, err @@ -256,7 +305,7 @@ func (ep *endpoint) Write(payload tcpip.Payload, opts tcpip.WriteOptions) (uintp // finishWrite writes the payload to a route. It resolves the route if // necessary. It's really just a helper to make defer unnecessary in Write. -func (ep *endpoint) finishWrite(payload tcpip.Payload, route *stack.Route) (uintptr, <-chan struct{}, *tcpip.Error) { +func (ep *endpoint) finishWrite(payloadBytes []byte, route *stack.Route) (uintptr, <-chan struct{}, *tcpip.Error) { // We may need to resolve the route (match a link layer address to the // network address). If that requires blocking (e.g. to use ARP), // return a channel on which the caller can wait. @@ -269,13 +318,14 @@ func (ep *endpoint) finishWrite(payload tcpip.Payload, route *stack.Route) (uint } } - payloadBytes, err := payload.Get(payload.Size()) - if err != nil { - return 0, nil, err - } - switch ep.netProto { case header.IPv4ProtocolNumber: + if !ep.associated { + if err := route.WriteHeaderIncludedPacket(buffer.View(payloadBytes).ToVectorisedView()); err != nil { + return 0, nil, err + } + break + } hdr := buffer.NewPrependable(len(payloadBytes) + int(route.MaxHeaderLength())) if err := route.WritePacket(nil /* gso */, hdr, buffer.View(payloadBytes).ToVectorisedView(), ep.transProto, route.DefaultTTL()); err != nil { return 0, nil, err @@ -335,15 +385,17 @@ func (ep *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { } defer route.Release() - // Re-register the endpoint with the appropriate NIC. - if err := ep.stack.RegisterRawTransportEndpoint(addr.NIC, ep.netProto, ep.transProto, ep); err != nil { - return err + if ep.associated { + // Re-register the endpoint with the appropriate NIC. + if err := ep.stack.RegisterRawTransportEndpoint(addr.NIC, ep.netProto, ep.transProto, ep); err != nil { + return err + } + ep.stack.UnregisterRawTransportEndpoint(ep.registeredNIC, ep.netProto, ep.transProto, ep) + ep.registeredNIC = nic } - ep.stack.UnregisterRawTransportEndpoint(ep.registeredNIC, ep.netProto, ep.transProto, ep) - // Save the route and NIC we've connected via. + // Save the route we've connected via. ep.route = route.Clone() - ep.registeredNIC = nic ep.connected = true return nil @@ -386,14 +438,16 @@ func (ep *endpoint) Bind(addr tcpip.FullAddress) *tcpip.Error { return tcpip.ErrBadLocalAddress } - // Re-register the endpoint with the appropriate NIC. - if err := ep.stack.RegisterRawTransportEndpoint(addr.NIC, ep.netProto, ep.transProto, ep); err != nil { - return err + if ep.associated { + // Re-register the endpoint with the appropriate NIC. + if err := ep.stack.RegisterRawTransportEndpoint(addr.NIC, ep.netProto, ep.transProto, ep); err != nil { + return err + } + ep.stack.UnregisterRawTransportEndpoint(ep.registeredNIC, ep.netProto, ep.transProto, ep) + ep.registeredNIC = addr.NIC + ep.boundNIC = addr.NIC } - ep.stack.UnregisterRawTransportEndpoint(ep.registeredNIC, ep.netProto, ep.transProto, ep) - ep.registeredNIC = addr.NIC - ep.boundNIC = addr.NIC ep.boundAddr = addr.Addr ep.bound = true diff --git a/pkg/tcpip/transport/raw/protocol.go b/pkg/tcpip/transport/raw/protocol.go new file mode 100644 index 000000000..783c21e6b --- /dev/null +++ b/pkg/tcpip/transport/raw/protocol.go @@ -0,0 +1,32 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package raw + +import ( + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/stack" + "gvisor.dev/gvisor/pkg/waiter" +) + +type factory struct{} + +// NewUnassociatedRawEndpoint implements stack.UnassociatedEndpointFactory. +func (factory) NewUnassociatedRawEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { + return newEndpoint(stack, netProto, transProto, waiterQueue, false /* associated */) +} + +func init() { + stack.RegisterUnassociatedFactory(factory{}) +} diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index cb40fea94..beb90afb5 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -117,6 +117,7 @@ const ( notifyDrain notifyReset notifyKeepaliveChanged + notifyMSSChanged ) // SACKInfo holds TCP SACK related information for a given endpoint. @@ -218,8 +219,6 @@ type endpoint struct { mu sync.RWMutex `state:"nosave"` id stack.TransportEndpointID - // state endpointState `state:".(endpointState)"` - // pState ProtocolState state EndpointState `state:".(EndpointState)"` isPortReserved bool `state:"manual"` @@ -313,6 +312,10 @@ type endpoint struct { // in SYN-RCVD state. synRcvdCount int + // userMSS if non-zero is the MSS value explicitly set by the user + // for this endpoint using the TCP_MAXSEG setsockopt. + userMSS int + // The following fields are used to manage the send buffer. When // segments are ready to be sent, they are added to sndQueue and the // protocol goroutine is signaled via sndWaker. @@ -917,6 +920,17 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { } return nil + case tcpip.MaxSegOption: + userMSS := v + if userMSS < header.TCPMinimumMSS || userMSS > header.TCPMaximumMSS { + return tcpip.ErrInvalidOptionValue + } + e.mu.Lock() + e.userMSS = int(userMSS) + e.mu.Unlock() + e.notifyProtocolGoroutine(notifyMSSChanged) + return nil + case tcpip.ReceiveBufferSizeOption: // Make sure the receive buffer size is within the min and max // allowed. @@ -1096,6 +1110,14 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { e.lastErrorMu.Unlock() return err + case *tcpip.MaxSegOption: + // This is just stubbed out. Linux never returns the user_mss + // value as it either returns the defaultMSS or returns the + // actual current MSS. Netstack just returns the defaultMSS + // always for now. + *o = header.TCPDefaultMSS + return nil + case *tcpip.SendBufferSizeOption: e.sndBufMu.Lock() *o = tcpip.SendBufferSizeOption(e.sndBufSize) diff --git a/pkg/tcpip/transport/tcp/testing/context/context.go b/pkg/tcpip/transport/tcp/testing/context/context.go index 630dd7925..bcc0f3e28 100644 --- a/pkg/tcpip/transport/tcp/testing/context/context.go +++ b/pkg/tcpip/transport/tcp/testing/context/context.go @@ -271,7 +271,7 @@ func (c *Context) GetPacketNonBlocking() []byte { // SendICMPPacket builds and sends an ICMPv4 packet via the link layer endpoint. func (c *Context) SendICMPPacket(typ header.ICMPv4Type, code uint8, p1, p2 []byte, maxTotalSize int) { // Allocate a buffer data and headers. - buf := buffer.NewView(header.IPv4MinimumSize + header.ICMPv4MinimumSize + len(p1) + len(p2)) + buf := buffer.NewView(header.IPv4MinimumSize + header.ICMPv4PayloadOffset + len(p1) + len(p2)) if len(buf) > maxTotalSize { buf = buf[:maxTotalSize] } @@ -291,8 +291,8 @@ func (c *Context) SendICMPPacket(typ header.ICMPv4Type, code uint8, p1, p2 []byt icmp.SetType(typ) icmp.SetCode(code) - copy(icmp[header.ICMPv4MinimumSize:], p1) - copy(icmp[header.ICMPv4MinimumSize+len(p1):], p2) + copy(icmp[header.ICMPv4PayloadOffset:], p1) + copy(icmp[header.ICMPv4PayloadOffset+len(p1):], p2) // Inject packet. c.linkEP.Inject(ipv4.ProtocolNumber, buf.ToVectorisedView()) |