diff options
Diffstat (limited to 'pkg/tcpip/transport')
22 files changed, 1428 insertions, 534 deletions
diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go index 8ce294002..4612be4e7 100644 --- a/pkg/tcpip/transport/icmp/endpoint.go +++ b/pkg/tcpip/transport/icmp/endpoint.go @@ -344,6 +344,10 @@ func (e *endpoint) Peek([][]byte) (int64, tcpip.ControlMessages, *tcpip.Error) { // SetSockOpt sets a socket option. func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { + switch opt.(type) { + case tcpip.SocketDetachFilterOption: + return nil + } return nil } @@ -744,15 +748,15 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pk // Only accept echo replies. switch e.NetProto { case header.IPv4ProtocolNumber: - h, ok := pkt.Data.PullUp(header.ICMPv4MinimumSize) - if !ok || header.ICMPv4(h).Type() != header.ICMPv4EchoReply { + h := header.ICMPv4(pkt.TransportHeader) + if len(h) < header.ICMPv4MinimumSize || h.Type() != header.ICMPv4EchoReply { e.stack.Stats().DroppedPackets.Increment() e.stats.ReceiveErrors.MalformedPacketsReceived.Increment() return } case header.IPv6ProtocolNumber: - h, ok := pkt.Data.PullUp(header.ICMPv6MinimumSize) - if !ok || header.ICMPv6(h).Type() != header.ICMPv6EchoReply { + h := header.ICMPv6(pkt.TransportHeader) + if len(h) < header.ICMPv6MinimumSize || h.Type() != header.ICMPv6EchoReply { e.stack.Stats().DroppedPackets.Increment() e.stats.ReceiveErrors.MalformedPacketsReceived.Increment() return @@ -786,12 +790,14 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pk }, } - packet.data = pkt.Data + // ICMP socket's data includes ICMP header. + packet.data = pkt.TransportHeader.ToVectorisedView() + packet.data.Append(pkt.Data) e.rcvList.PushBack(packet) e.rcvBufSize += packet.data.Size() - packet.timestamp = e.stack.NowNanoseconds() + packet.timestamp = e.stack.Clock().NowNanoseconds() e.rcvMu.Unlock() e.stats.PacketsReceived.Increment() diff --git a/pkg/tcpip/transport/packet/endpoint.go b/pkg/tcpip/transport/packet/endpoint.go index baf08eda6..df478115d 100644 --- a/pkg/tcpip/transport/packet/endpoint.go +++ b/pkg/tcpip/transport/packet/endpoint.go @@ -25,6 +25,8 @@ package packet import ( + "fmt" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" @@ -43,6 +45,9 @@ type packet struct { timestampNS int64 // senderAddr is the network address of the sender. senderAddr tcpip.FullAddress + // packetInfo holds additional information like the protocol + // of the packet etc. + packetInfo tcpip.LinkPacketInfo } // endpoint is the packet socket implementation of tcpip.Endpoint. It is legal @@ -71,11 +76,17 @@ type endpoint struct { rcvClosed bool // The following fields are protected by mu. - mu sync.RWMutex `state:"nosave"` - sndBufSize int - closed bool - stats tcpip.TransportEndpointStats `state:"nosave"` - bound bool + mu sync.RWMutex `state:"nosave"` + sndBufSize int + sndBufSizeMax int + closed bool + stats tcpip.TransportEndpointStats `state:"nosave"` + bound bool + boundNIC tcpip.NICID + + // lastErrorMu protects lastError. + lastErrorMu sync.Mutex `state:"nosave"` + lastError *tcpip.Error `state:".(string)"` } // NewEndpoint returns a new packet endpoint. @@ -92,6 +103,17 @@ func NewEndpoint(s *stack.Stack, cooked bool, netProto tcpip.NetworkProtocolNumb sndBufSize: 32 * 1024, } + // Override with stack defaults. + var ss stack.SendBufferSizeOption + if err := s.Option(&ss); err == nil { + ep.sndBufSizeMax = ss.Default + } + + var rs stack.ReceiveBufferSizeOption + if err := s.Option(&rs); err == nil { + ep.rcvBufSizeMax = rs.Default + } + if err := s.RegisterPacketEndpoint(0, netProto, ep); err != nil { return nil, err } @@ -132,8 +154,8 @@ func (ep *endpoint) Close() { // ModerateRecvBuf implements tcpip.Endpoint.ModerateRecvBuf. func (ep *endpoint) ModerateRecvBuf(copied int) {} -// Read implements tcpip.Endpoint.Read. -func (ep *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) { +// Read implements tcpip.PacketEndpoint.ReadPacket. +func (ep *endpoint) ReadPacket(addr *tcpip.FullAddress, info *tcpip.LinkPacketInfo) (buffer.View, tcpip.ControlMessages, *tcpip.Error) { ep.rcvMu.Lock() // If there's no data to read, return that read would block or that the @@ -158,11 +180,20 @@ func (ep *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMes *addr = packet.senderAddr } + if info != nil { + *info = packet.packetInfo + } + return packet.data.ToView(), tcpip.ControlMessages{HasTimestamp: true, Timestamp: packet.timestampNS}, nil } +// Read implements tcpip.Endpoint.Read. +func (ep *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) { + return ep.ReadPacket(addr, nil) +} + func (ep *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) { - // TODO(b/129292371): Implement. + // TODO(gvisor.dev/issue/173): Implement. return 0, nil, tcpip.ErrInvalidOptionValue } @@ -215,12 +246,14 @@ func (ep *endpoint) Bind(addr tcpip.FullAddress) *tcpip.Error { ep.mu.Lock() defer ep.mu.Unlock() - if ep.bound { - return tcpip.ErrAlreadyBound + if ep.bound && ep.boundNIC == addr.NIC { + // If the NIC being bound is the same then just return success. + return nil } // Unregister endpoint with all the nics. ep.stack.UnregisterPacketEndpoint(0, ep.netProto, ep) + ep.bound = false // Bind endpoint to receive packets from specific interface. if err := ep.stack.RegisterPacketEndpoint(addr.NIC, ep.netProto, ep); err != nil { @@ -228,6 +261,7 @@ func (ep *endpoint) Bind(addr tcpip.FullAddress) *tcpip.Error { } ep.bound = true + ep.boundNIC = addr.NIC return nil } @@ -264,7 +298,13 @@ func (ep *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask { // used with SetSockOpt, and this function always returns // tcpip.ErrNotSupported. func (ep *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { - return tcpip.ErrUnknownProtocolOption + switch opt.(type) { + case tcpip.SocketDetachFilterOption: + return nil + + default: + return tcpip.ErrUnknownProtocolOption + } } // SetSockOptBool implements tcpip.Endpoint.SetSockOptBool. @@ -274,11 +314,63 @@ func (ep *endpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error { // SetSockOptInt implements tcpip.Endpoint.SetSockOptInt. func (ep *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { - return tcpip.ErrUnknownProtocolOption + switch opt { + case tcpip.SendBufferSizeOption: + // Make sure the send buffer size is within the min and max + // allowed. + var ss stack.SendBufferSizeOption + if err := ep.stack.Option(&ss); err != nil { + panic(fmt.Sprintf("s.Option(%#v) = %s", ss, err)) + } + if v > ss.Max { + v = ss.Max + } + if v < ss.Min { + v = ss.Min + } + ep.mu.Lock() + ep.sndBufSizeMax = v + ep.mu.Unlock() + return nil + + case tcpip.ReceiveBufferSizeOption: + // Make sure the receive buffer size is within the min and max + // allowed. + var rs stack.ReceiveBufferSizeOption + if err := ep.stack.Option(&rs); err != nil { + panic(fmt.Sprintf("s.Option(%#v) = %s", rs, err)) + } + if v > rs.Max { + v = rs.Max + } + if v < rs.Min { + v = rs.Min + } + ep.rcvMu.Lock() + ep.rcvBufSizeMax = v + ep.rcvMu.Unlock() + return nil + + default: + return tcpip.ErrUnknownProtocolOption + } +} + +func (ep *endpoint) takeLastError() *tcpip.Error { + ep.lastErrorMu.Lock() + defer ep.lastErrorMu.Unlock() + + err := ep.lastError + ep.lastError = nil + return err } // GetSockOpt implements tcpip.Endpoint.GetSockOpt. func (ep *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { + switch opt.(type) { + case tcpip.ErrorOption: + return ep.takeLastError() + } return tcpip.ErrNotSupported } @@ -289,7 +381,32 @@ func (ep *endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) { // GetSockOptInt implements tcpip.Endpoint.GetSockOptInt. func (ep *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { - return 0, tcpip.ErrNotSupported + switch opt { + case tcpip.ReceiveQueueSizeOption: + v := 0 + ep.rcvMu.Lock() + if !ep.rcvList.Empty() { + p := ep.rcvList.Front() + v = p.data.Size() + } + ep.rcvMu.Unlock() + return v, nil + + case tcpip.SendBufferSizeOption: + ep.mu.Lock() + v := ep.sndBufSizeMax + ep.mu.Unlock() + return v, nil + + case tcpip.ReceiveBufferSizeOption: + ep.rcvMu.Lock() + v := ep.rcvBufSizeMax + ep.rcvMu.Unlock() + return v, nil + + default: + return -1, tcpip.ErrUnknownProtocolOption + } } // HandlePacket implements stack.PacketEndpoint.HandlePacket. @@ -315,7 +432,7 @@ func (ep *endpoint) HandlePacket(nicID tcpip.NICID, localAddr tcpip.LinkAddress, // Push new packet into receive list and increment the buffer size. var packet packet - // TODO(b/129292371): Return network protocol. + // TODO(gvisor.dev/issue/173): Return network protocol. if len(pkt.LinkHeader) > 0 { // Get info directly from the ethernet header. hdr := header.Ethernet(pkt.LinkHeader) @@ -323,40 +440,66 @@ func (ep *endpoint) HandlePacket(nicID tcpip.NICID, localAddr tcpip.LinkAddress, NIC: nicID, Addr: tcpip.Address(hdr.SourceAddress()), } + packet.packetInfo.Protocol = netProto + packet.packetInfo.PktType = pkt.PktType } else { // Guess the would-be ethernet header. packet.senderAddr = tcpip.FullAddress{ NIC: nicID, Addr: tcpip.Address(localAddr), } + packet.packetInfo.Protocol = netProto + packet.packetInfo.PktType = pkt.PktType } if ep.cooked { // Cooked packets can simply be queued. - packet.data = pkt.Data + switch pkt.PktType { + case tcpip.PacketHost: + packet.data = pkt.Data + case tcpip.PacketOutgoing: + // Strip Link Header from the Header. + pkt.Header = buffer.NewPrependableFromView(pkt.Header.View()[len(pkt.LinkHeader):]) + combinedVV := pkt.Header.View().ToVectorisedView() + combinedVV.Append(pkt.Data) + packet.data = combinedVV + default: + panic(fmt.Sprintf("unexpected PktType in pkt: %+v", pkt)) + } + } else { // Raw packets need their ethernet headers prepended before // queueing. var linkHeader buffer.View - if len(pkt.LinkHeader) == 0 { - // We weren't provided with an actual ethernet header, - // so fake one. - ethFields := header.EthernetFields{ - SrcAddr: tcpip.LinkAddress([]byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x00}), - DstAddr: localAddr, - Type: netProto, + var combinedVV buffer.VectorisedView + if pkt.PktType != tcpip.PacketOutgoing { + if len(pkt.LinkHeader) == 0 { + // We weren't provided with an actual ethernet header, + // so fake one. + ethFields := header.EthernetFields{ + SrcAddr: tcpip.LinkAddress([]byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x00}), + DstAddr: localAddr, + Type: netProto, + } + fakeHeader := make(header.Ethernet, header.EthernetMinimumSize) + fakeHeader.Encode(ðFields) + linkHeader = buffer.View(fakeHeader) + } else { + linkHeader = append(buffer.View(nil), pkt.LinkHeader...) } - fakeHeader := make(header.Ethernet, header.EthernetMinimumSize) - fakeHeader.Encode(ðFields) - linkHeader = buffer.View(fakeHeader) - } else { - linkHeader = append(buffer.View(nil), pkt.LinkHeader...) + combinedVV = linkHeader.ToVectorisedView() + } + if pkt.PktType == tcpip.PacketOutgoing { + // For outgoing packets the Link, Network and Transport + // headers are in the pkt.Header fields normally unless + // a Raw socket is in use in which case pkt.Header could + // be nil. + combinedVV.AppendView(pkt.Header.View()) } - combinedVV := linkHeader.ToVectorisedView() combinedVV.Append(pkt.Data) packet.data = combinedVV } - packet.timestampNS = ep.stack.NowNanoseconds() + packet.timestampNS = ep.stack.Clock().NowNanoseconds() ep.rcvList.PushBack(&packet) ep.rcvBufSize += packet.data.Size() diff --git a/pkg/tcpip/transport/packet/endpoint_state.go b/pkg/tcpip/transport/packet/endpoint_state.go index 9b88f17e4..e2fa96d17 100644 --- a/pkg/tcpip/transport/packet/endpoint_state.go +++ b/pkg/tcpip/transport/packet/endpoint_state.go @@ -15,6 +15,7 @@ package packet import ( + "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/stack" ) @@ -70,3 +71,21 @@ func (ep *endpoint) afterLoad() { panic(*err) } } + +// saveLastError is invoked by stateify. +func (ep *endpoint) saveLastError() string { + if ep.lastError == nil { + return "" + } + + return ep.lastError.String() +} + +// loadLastError is invoked by stateify. +func (ep *endpoint) loadLastError(s string) { + if s == "" { + return + } + + ep.lastError = tcpip.StringToError(s) +} diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go index 6a7977259..f85a68554 100644 --- a/pkg/tcpip/transport/raw/endpoint.go +++ b/pkg/tcpip/transport/raw/endpoint.go @@ -63,6 +63,7 @@ type endpoint struct { stack *stack.Stack `state:"manual"` waiterQueue *waiter.Queue associated bool + hdrIncluded bool // The following fields are used to manage the receive queue and are // protected by rcvMu. @@ -94,7 +95,7 @@ func NewEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, trans } func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue, associated bool) (tcpip.Endpoint, *tcpip.Error) { - if netProto != header.IPv4ProtocolNumber { + if netProto != header.IPv4ProtocolNumber && netProto != header.IPv6ProtocolNumber { return nil, tcpip.ErrUnknownProtocol } @@ -108,16 +109,17 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProt rcvBufSizeMax: 32 * 1024, sndBufSizeMax: 32 * 1024, associated: associated, + hdrIncluded: !associated, } // Override with stack defaults. - var ss tcpip.StackSendBufferSizeOption - if err := s.TransportProtocolOption(transProto, &ss); err == nil { + var ss stack.SendBufferSizeOption + if err := s.Option(&ss); err == nil { e.sndBufSizeMax = ss.Default } - var rs tcpip.StackReceiveBufferSizeOption - if err := s.TransportProtocolOption(transProto, &rs); err == nil { + var rs stack.ReceiveBufferSizeOption + if err := s.Option(&rs); err == nil { e.rcvBufSizeMax = rs.Default } @@ -182,10 +184,6 @@ func (e *endpoint) SetOwner(owner tcpip.PacketOwner) { // Read implements tcpip.Endpoint.Read. func (e *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) { - if !e.associated { - return buffer.View{}, tcpip.ControlMessages{}, tcpip.ErrInvalidOptionValue - } - e.rcvMu.Lock() // If there's no data to read, return that read would block or that the @@ -215,6 +213,11 @@ func (e *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMess // Write implements tcpip.Endpoint.Write. func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) { + // We can create, but not write to, unassociated IPv6 endpoints. + if !e.associated && e.TransportEndpointInfo.NetProto == header.IPv6ProtocolNumber { + return 0, nil, tcpip.ErrInvalidOptionValue + } + n, ch, err := e.write(p, opts) switch err { case nil: @@ -258,7 +261,7 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c // If this is an unassociated socket and callee provided a nonzero // destination address, route using that address. - if !e.associated { + if e.hdrIncluded { ip := header.IPv4(payloadBytes) if !ip.IsValid(len(payloadBytes)) { e.mu.RUnlock() @@ -319,12 +322,6 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c return 0, nil, tcpip.ErrNoRoute } - // We don't support IPv6 yet, so this has to be an IPv4 address. - if len(opts.To.Addr) != header.IPv4AddressSize { - e.mu.RUnlock() - return 0, nil, tcpip.ErrInvalidEndpointState - } - // Find the route to the destination. If BindAddress is 0, // FindRoute will choose an appropriate source address. route, err := e.stack.FindRoute(nic, e.BindAddr, opts.To.Addr, e.NetProto, false) @@ -354,17 +351,13 @@ func (e *endpoint) finishWrite(payloadBytes []byte, route *stack.Route) (int64, } } - switch e.NetProto { - case header.IPv4ProtocolNumber: - if !e.associated { - if err := route.WriteHeaderIncludedPacket(&stack.PacketBuffer{ - Data: buffer.View(payloadBytes).ToVectorisedView(), - }); err != nil { - return 0, nil, err - } - break + if e.hdrIncluded { + if err := route.WriteHeaderIncludedPacket(&stack.PacketBuffer{ + Data: buffer.View(payloadBytes).ToVectorisedView(), + }); err != nil { + return 0, nil, err } - + } else { hdr := buffer.NewPrependable(len(payloadBytes) + int(route.MaxHeaderLength())) if err := route.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: e.TransProto, TTL: route.DefaultTTL(), TOS: stack.DefaultTOS}, &stack.PacketBuffer{ Header: hdr, @@ -373,9 +366,6 @@ func (e *endpoint) finishWrite(payloadBytes []byte, route *stack.Route) (int64, }); err != nil { return 0, nil, err } - - default: - return 0, nil, tcpip.ErrUnknownProtocol } return int64(len(payloadBytes)), nil, nil @@ -400,11 +390,6 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { return tcpip.ErrInvalidEndpointState } - // We don't support IPv6 yet. - if len(addr.Addr) != header.IPv4AddressSize { - return tcpip.ErrInvalidEndpointState - } - nic := addr.NIC if e.bound { if e.BindNICID == 0 { @@ -470,14 +455,8 @@ func (e *endpoint) Bind(addr tcpip.FullAddress) *tcpip.Error { e.mu.Lock() defer e.mu.Unlock() - // Callers must provide an IPv4 address or no network address (for - // binding to a NIC, but not an address). - if len(addr.Addr) != 0 && len(addr.Addr) != 4 { - return tcpip.ErrInvalidEndpointState - } - // If a local address was specified, verify that it's valid. - if len(addr.Addr) == header.IPv4AddressSize && e.stack.CheckLocalAddress(addr.NIC, e.NetProto, addr.Addr) == 0 { + if len(addr.Addr) != 0 && e.stack.CheckLocalAddress(addr.NIC, e.NetProto, addr.Addr) == 0 { return tcpip.ErrBadLocalAddress } @@ -527,11 +506,24 @@ func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask { // SetSockOpt implements tcpip.Endpoint.SetSockOpt. func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { - return tcpip.ErrUnknownProtocolOption + switch opt.(type) { + case tcpip.SocketDetachFilterOption: + return nil + + default: + return tcpip.ErrUnknownProtocolOption + } } // SetSockOptBool implements tcpip.Endpoint.SetSockOptBool. func (e *endpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error { + switch opt { + case tcpip.IPHdrIncludedOption: + e.mu.Lock() + e.hdrIncluded = v + e.mu.Unlock() + return nil + } return tcpip.ErrUnknownProtocolOption } @@ -541,9 +533,9 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { case tcpip.SendBufferSizeOption: // Make sure the send buffer size is within the min and max // allowed. - var ss tcpip.StackSendBufferSizeOption - if err := e.stack.TransportProtocolOption(e.TransProto, &ss); err != nil { - panic(fmt.Sprintf("s.TransportProtocolOption(%d, %+v) = %s", e.TransProto, ss, err)) + var ss stack.SendBufferSizeOption + if err := e.stack.Option(&ss); err != nil { + panic(fmt.Sprintf("s.Option(%#v) = %s", ss, err)) } if v > ss.Max { v = ss.Max @@ -559,9 +551,9 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { case tcpip.ReceiveBufferSizeOption: // Make sure the receive buffer size is within the min and max // allowed. - var rs tcpip.StackReceiveBufferSizeOption - if err := e.stack.TransportProtocolOption(e.TransProto, &rs); err != nil { - panic(fmt.Sprintf("s.TransportProtocolOption(%d, %+v) = %s", e.TransProto, rs, err)) + var rs stack.ReceiveBufferSizeOption + if err := e.stack.Option(&rs); err != nil { + panic(fmt.Sprintf("s.Option(%#v) = %s", rs, err)) } if v > rs.Max { v = rs.Max @@ -596,6 +588,12 @@ func (e *endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) { case tcpip.KeepaliveEnabledOption: return false, nil + case tcpip.IPHdrIncludedOption: + e.mu.Lock() + v := e.hdrIncluded + e.mu.Unlock() + return v, nil + default: return false, tcpip.ErrUnknownProtocolOption } @@ -635,8 +633,15 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { func (e *endpoint) HandlePacket(route *stack.Route, pkt *stack.PacketBuffer) { e.rcvMu.Lock() - // Drop the packet if our buffer is currently full. - if e.rcvClosed { + // Drop the packet if our buffer is currently full or if this is an unassociated + // endpoint (i.e endpoint created w/ IPPROTO_RAW). Such endpoints are send only + // See: https://man7.org/linux/man-pages/man7/raw.7.html + // + // An IPPROTO_RAW socket is send only. If you really want to receive + // all IP packets, use a packet(7) socket with the ETH_P_IP protocol. + // Note that packet sockets don't reassemble IP fragments, unlike raw + // sockets. + if e.rcvClosed || !e.associated { e.rcvMu.Unlock() e.stack.Stats().DroppedPackets.Increment() e.stats.ReceiveErrors.ClosedReceiver.Increment() @@ -680,12 +685,22 @@ func (e *endpoint) HandlePacket(route *stack.Route, pkt *stack.PacketBuffer) { }, } - headers := append(buffer.View(nil), pkt.NetworkHeader...) - headers = append(headers, pkt.TransportHeader...) - combinedVV := headers.ToVectorisedView() + // Raw IPv4 endpoints return the IP header, but IPv6 endpoints do not. + // We copy headers' underlying bytes because pkt.*Header may point to + // the middle of a slice, and another struct may point to the "outer" + // slice. Save/restore doesn't support overlapping slices and will fail. + var combinedVV buffer.VectorisedView + if e.TransportEndpointInfo.NetProto == header.IPv4ProtocolNumber { + headers := make(buffer.View, 0, len(pkt.NetworkHeader)+len(pkt.TransportHeader)) + headers = append(headers, pkt.NetworkHeader...) + headers = append(headers, pkt.TransportHeader...) + combinedVV = headers.ToVectorisedView() + } else { + combinedVV = append(buffer.View(nil), pkt.TransportHeader...).ToVectorisedView() + } combinedVV.Append(pkt.Data) packet.data = combinedVV - packet.timestampNS = e.stack.NowNanoseconds() + packet.timestampNS = e.stack.Clock().NowNanoseconds() e.rcvList.PushBack(packet) e.rcvBufSize += packet.data.Size() diff --git a/pkg/tcpip/transport/tcp/BUILD b/pkg/tcpip/transport/tcp/BUILD index 6baeda8e4..e860ee484 100644 --- a/pkg/tcpip/transport/tcp/BUILD +++ b/pkg/tcpip/transport/tcp/BUILD @@ -49,6 +49,7 @@ go_library( "segment_heap.go", "segment_queue.go", "segment_state.go", + "segment_unsafe.go", "snd.go", "snd_state.go", "tcp_endpoint_list.go", @@ -86,6 +87,7 @@ go_test( "tcp_test.go", "tcp_timestamp_test.go", ], + shard_count = 10, deps = [ ":tcp", "//pkg/sync", diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go index 7679fe169..6e00e5526 100644 --- a/pkg/tcpip/transport/tcp/accept.go +++ b/pkg/tcpip/transport/tcp/accept.go @@ -27,7 +27,6 @@ import ( "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/header" - "gvisor.dev/gvisor/pkg/tcpip/ports" "gvisor.dev/gvisor/pkg/tcpip/seqnum" "gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/waiter" @@ -199,9 +198,8 @@ func (l *listenContext) isCookieValid(id stack.TransportEndpointID, cookie seqnu } // createConnectingEndpoint creates a new endpoint in a connecting state, with -// the connection parameters given by the arguments. The endpoint is returned -// with n.mu held. -func (l *listenContext) createConnectingEndpoint(s *segment, iss seqnum.Value, irs seqnum.Value, rcvdSynOpts *header.TCPSynOptions, queue *waiter.Queue) (*endpoint, *tcpip.Error) { +// the connection parameters given by the arguments. +func (l *listenContext) createConnectingEndpoint(s *segment, iss seqnum.Value, irs seqnum.Value, rcvdSynOpts *header.TCPSynOptions, queue *waiter.Queue) *endpoint { // Create a new endpoint. netProto := l.netProto if netProto == 0 { @@ -227,22 +225,7 @@ func (l *listenContext) createConnectingEndpoint(s *segment, iss seqnum.Value, i // window to grow to a really large value. n.rcvAutoParams.prevCopied = n.initialReceiveWindow() - // Lock the endpoint before registering to ensure that no out of - // band changes are possible due to incoming packets etc till - // the endpoint is done initializing. - n.mu.Lock() - - // Register new endpoint so that packets are routed to it. - if err := n.stack.RegisterTransportEndpoint(n.boundNICID, n.effectiveNetProtos, ProtocolNumber, n.ID, n, ports.Flags{LoadBalanced: n.reusePort}, n.boundBindToDevice); err != nil { - n.mu.Unlock() - n.Close() - return nil, err - } - - n.isRegistered = true - n.registeredReusePort = n.reusePort - - return n, nil + return n } // createEndpointAndPerformHandshake creates a new endpoint in connected state @@ -253,10 +236,12 @@ func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *head // Create new endpoint. irs := s.sequenceNumber isn := generateSecureISN(s.id, l.stack.Seed()) - ep, err := l.createConnectingEndpoint(s, isn, irs, opts, queue) - if err != nil { - return nil, err - } + ep := l.createConnectingEndpoint(s, isn, irs, opts, queue) + + // Lock the endpoint before registering to ensure that no out of + // band changes are possible due to incoming packets etc till + // the endpoint is done initializing. + ep.mu.Lock() ep.owner = owner // listenEP is nil when listenContext is used by tcp.Forwarder. @@ -264,18 +249,13 @@ func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *head if l.listenEP != nil { l.listenEP.mu.Lock() if l.listenEP.EndpointState() != StateListen { + l.listenEP.mu.Unlock() // Ensure we release any registrations done by the newly // created endpoint. ep.mu.Unlock() ep.Close() - // Wake up any waiters. This is strictly not required normally - // as a socket that was never accepted can't really have any - // registered waiters except when stack.Wait() is called which - // waits for all registered endpoints to stop and expects an - // EventHUp. - ep.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.EventIn | waiter.EventOut) return nil, tcpip.ErrConnectionAborted } l.addPendingEndpoint(ep) @@ -284,21 +264,44 @@ func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *head // to the newly created endpoint. l.listenEP.propagateInheritableOptionsLocked(ep) + if !ep.reserveTupleLocked() { + ep.mu.Unlock() + ep.Close() + + if l.listenEP != nil { + l.removePendingEndpoint(ep) + l.listenEP.mu.Unlock() + } + + return nil, tcpip.ErrConnectionAborted + } + deferAccept = l.listenEP.deferAccept l.listenEP.mu.Unlock() } + // Register new endpoint so that packets are routed to it. + if err := ep.stack.RegisterTransportEndpoint(ep.boundNICID, ep.effectiveNetProtos, ProtocolNumber, ep.ID, ep, ep.boundPortFlags, ep.boundBindToDevice); err != nil { + ep.mu.Unlock() + ep.Close() + + if l.listenEP != nil { + l.removePendingEndpoint(ep) + } + + ep.drainClosingSegmentQueue() + + return nil, err + } + + ep.isRegistered = true + // Perform the 3-way handshake. h := newPassiveHandshake(ep, seqnum.Size(ep.initialReceiveWindow()), isn, irs, opts, deferAccept) if err := h.execute(); err != nil { ep.mu.Unlock() ep.Close() - // Wake up any waiters. This is strictly not required normally - // as a socket that was never accepted can't really have any - // registered waiters except when stack.Wait() is called which - // waits for all registered endpoints to stop and expects an - // EventHUp. - ep.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.EventIn | waiter.EventOut) + ep.notifyAborted() if l.listenEP != nil { l.removePendingEndpoint(ep) @@ -374,6 +377,43 @@ func (e *endpoint) deliverAccepted(n *endpoint) { // Precondition: e.mu and n.mu must be held. func (e *endpoint) propagateInheritableOptionsLocked(n *endpoint) { n.userTimeout = e.userTimeout + n.portFlags = e.portFlags + n.boundBindToDevice = e.boundBindToDevice + n.boundPortFlags = e.boundPortFlags +} + +// reserveTupleLocked reserves an accepted endpoint's tuple. +// +// Preconditions: +// * propagateInheritableOptionsLocked has been called. +// * e.mu is held. +func (e *endpoint) reserveTupleLocked() bool { + dest := tcpip.FullAddress{Addr: e.ID.RemoteAddress, Port: e.ID.RemotePort} + if !e.stack.ReserveTuple( + e.effectiveNetProtos, + ProtocolNumber, + e.ID.LocalAddress, + e.ID.LocalPort, + e.boundPortFlags, + e.boundBindToDevice, + dest, + ) { + return false + } + + e.isPortReserved = true + e.boundDest = dest + return true +} + +// notifyAborted wakes up any waiters on registered, but not accepted +// endpoints. +// +// This is strictly not required normally as a socket that was never accepted +// can't really have any registered waiters except when stack.Wait() is called +// which waits for all registered endpoints to stop and expects an EventHUp. +func (e *endpoint) notifyAborted() { + e.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.EventIn | waiter.EventOut) } // handleSynSegment is called in its own goroutine once the listening endpoint @@ -568,16 +608,34 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) { rcvdSynOptions.TSEcr = s.parsedOptions.TSEcr } - n, err := ctx.createConnectingEndpoint(s, iss, irs, rcvdSynOptions, &waiter.Queue{}) - if err != nil { + n := ctx.createConnectingEndpoint(s, iss, irs, rcvdSynOptions, &waiter.Queue{}) + + n.mu.Lock() + + // Propagate any inheritable options from the listening endpoint + // to the newly created endpoint. + e.propagateInheritableOptionsLocked(n) + + if !n.reserveTupleLocked() { + n.mu.Unlock() + n.Close() + e.stack.Stats().TCP.FailedConnectionAttempts.Increment() e.stats.FailedConnectionAttempts.Increment() return } - // Propagate any inheritable options from the listening endpoint - // to the newly created endpoint. - e.propagateInheritableOptionsLocked(n) + // Register new endpoint so that packets are routed to it. + if err := n.stack.RegisterTransportEndpoint(n.boundNICID, n.effectiveNetProtos, ProtocolNumber, n.ID, n, n.boundPortFlags, n.boundBindToDevice); err != nil { + n.mu.Unlock() + n.Close() + + e.stack.Stats().TCP.FailedConnectionAttempts.Increment() + e.stats.FailedConnectionAttempts.Increment() + return + } + + n.isRegistered = true // clear the tsOffset for the newly created // endpoint as the Timestamp was already diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go index 377643b82..1798510bc 100644 --- a/pkg/tcpip/transport/tcp/connect.go +++ b/pkg/tcpip/transport/tcp/connect.go @@ -490,6 +490,9 @@ func (h *handshake) resolveRoute() *tcpip.Error { <-h.ep.undrain h.ep.mu.Lock() } + if n¬ifyError != 0 { + return h.ep.takeLastError() + } } // Wait for notification. @@ -509,9 +512,7 @@ func (h *handshake) execute() *tcpip.Error { // Initialize the resend timer. resendWaker := sleep.Waker{} timeOut := time.Duration(time.Second) - rt := time.AfterFunc(timeOut, func() { - resendWaker.Assert() - }) + rt := time.AfterFunc(timeOut, resendWaker.Assert) defer rt.Stop() // Set up the wakers. @@ -521,7 +522,7 @@ func (h *handshake) execute() *tcpip.Error { s.AddWaker(&h.ep.newSegmentWaker, wakerForNewSegment) defer s.Done() - var sackEnabled tcpip.StackSACKEnabled + var sackEnabled SACKEnabled if err := h.ep.stack.TransportProtocolOption(ProtocolNumber, &sackEnabled); err != nil { // If stack returned an error when checking for SACKEnabled // status then just default to switching off SACK negotiation. @@ -618,6 +619,9 @@ func (h *handshake) execute() *tcpip.Error { <-h.ep.undrain h.ep.mu.Lock() } + if n¬ifyError != 0 { + return h.ep.takeLastError() + } case wakerForNewSegment: if err := h.processSegments(); err != nil { @@ -1050,8 +1054,8 @@ func (e *endpoint) tryDeliverSegmentFromClosedEndpoint(s *segment) { panic("current endpoint not removed from demuxer, enqueing segments to itself") } - if ep.(*endpoint).enqueueSegment(s) { - ep.(*endpoint).newSegmentWaker.Assert() + if ep := ep.(*endpoint); ep.enqueueSegment(s) { + ep.newSegmentWaker.Assert() } } @@ -1120,7 +1124,7 @@ func (e *endpoint) handleReset(s *segment) (ok bool, err *tcpip.Error) { func (e *endpoint) handleSegments(fastPath bool) *tcpip.Error { checkRequeue := true for i := 0; i < maxSegmentsPerWake; i++ { - if e.EndpointState() == StateClose || e.EndpointState() == StateError { + if e.EndpointState().closed() { return nil } s := e.segmentQueue.dequeue() @@ -1440,9 +1444,7 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{ if e.EndpointState() == StateFinWait2 && e.closed { // The socket has been closed and we are in FIN_WAIT2 // so start the FIN_WAIT2 timer. - closeTimer = time.AfterFunc(e.tcpLingerTimeout, func() { - closeWaker.Assert() - }) + closeTimer = time.AfterFunc(e.tcpLingerTimeout, closeWaker.Assert) e.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.EventIn | waiter.EventOut) } } @@ -1460,7 +1462,7 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{ return err } } - if e.EndpointState() != StateClose && e.EndpointState() != StateError { + if !e.EndpointState().closed() { // Only block the worker if the endpoint // is not in closed state or error state. close(e.drainDone) @@ -1526,7 +1528,12 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{ } loop: - for e.EndpointState() != StateTimeWait && e.EndpointState() != StateClose && e.EndpointState() != StateError { + for { + switch e.EndpointState() { + case StateTimeWait, StateClose, StateError: + break loop + } + e.mu.Unlock() v, _ := s.Fetch(true) e.mu.Lock() diff --git a/pkg/tcpip/transport/tcp/dispatcher.go b/pkg/tcpip/transport/tcp/dispatcher.go index 047704c80..98aecab9e 100644 --- a/pkg/tcpip/transport/tcp/dispatcher.go +++ b/pkg/tcpip/transport/tcp/dispatcher.go @@ -15,6 +15,8 @@ package tcp import ( + "encoding/binary" + "gvisor.dev/gvisor/pkg/rand" "gvisor.dev/gvisor/pkg/sleep" "gvisor.dev/gvisor/pkg/sync" @@ -66,89 +68,68 @@ func (q *epQueue) empty() bool { // processor is responsible for processing packets queued to a tcp endpoint. type processor struct { epQ epQueue + sleeper sleep.Sleeper newEndpointWaker sleep.Waker closeWaker sleep.Waker - id int - wg sync.WaitGroup -} - -func newProcessor(id int) *processor { - p := &processor{ - id: id, - } - p.wg.Add(1) - go p.handleSegments() - return p } func (p *processor) close() { p.closeWaker.Assert() } -func (p *processor) wait() { - p.wg.Wait() -} - func (p *processor) queueEndpoint(ep *endpoint) { // Queue an endpoint for processing by the processor goroutine. p.epQ.enqueue(ep) p.newEndpointWaker.Assert() } -func (p *processor) handleSegments() { - const newEndpointWaker = 1 - const closeWaker = 2 - s := sleep.Sleeper{} - s.AddWaker(&p.newEndpointWaker, newEndpointWaker) - s.AddWaker(&p.closeWaker, closeWaker) - defer s.Done() +const ( + newEndpointWaker = 1 + closeWaker = 2 +) + +func (p *processor) start(wg *sync.WaitGroup) { + defer wg.Done() + defer p.sleeper.Done() + for { - id, ok := s.Fetch(true) - if ok && id == closeWaker { - p.wg.Done() - return + if id, _ := p.sleeper.Fetch(true); id == closeWaker { + break } - for ep := p.epQ.dequeue(); ep != nil; ep = p.epQ.dequeue() { + for { + ep := p.epQ.dequeue() + if ep == nil { + break + } if ep.segmentQueue.empty() { continue } - // If socket has transitioned out of connected state - // then just let the worker handle the packet. + // If socket has transitioned out of connected state then just let the + // worker handle the packet. // - // NOTE: We read this outside of e.mu lock which means - // that by the time we get to handleSegments the - // endpoint may not be in ESTABLISHED. But this should - // be fine as all normal shutdown states are handled by - // handleSegments and if the endpoint moves to a - // CLOSED/ERROR state then handleSegments is a noop. - if ep.EndpointState() != StateEstablished { - ep.newSegmentWaker.Assert() - continue - } - - if !ep.mu.TryLock() { - ep.newSegmentWaker.Assert() - continue - } - // If the endpoint is in a connected state then we do - // direct delivery to ensure low latency and avoid - // scheduler interactions. - if err := ep.handleSegments(true /* fastPath */); err != nil || ep.EndpointState() == StateClose { - // Send any active resets if required. - if err != nil { + // NOTE: We read this outside of e.mu lock which means that by the time + // we get to handleSegments the endpoint may not be in ESTABLISHED. But + // this should be fine as all normal shutdown states are handled by + // handleSegments and if the endpoint moves to a CLOSED/ERROR state + // then handleSegments is a noop. + if ep.EndpointState() == StateEstablished && ep.mu.TryLock() { + // If the endpoint is in a connected state then we do direct delivery + // to ensure low latency and avoid scheduler interactions. + switch err := ep.handleSegments(true /* fastPath */); { + case err != nil: + // Send any active resets if required. ep.resetConnectionLocked(err) + fallthrough + case ep.EndpointState() == StateClose: + ep.notifyProtocolGoroutine(notifyTickleWorker) + case !ep.segmentQueue.empty(): + p.epQ.enqueue(ep) } - ep.notifyProtocolGoroutine(notifyTickleWorker) ep.mu.Unlock() - continue - } - - if !ep.segmentQueue.empty() { - p.epQ.enqueue(ep) + } else { + ep.newSegmentWaker.Assert() } - - ep.mu.Unlock() } } } @@ -159,31 +140,36 @@ func (p *processor) handleSegments() { // hash of the endpoint id to ensure that delivery for the same endpoint happens // in-order. type dispatcher struct { - processors []*processor + processors []processor seed uint32 -} - -func newDispatcher(nProcessors int) *dispatcher { - processors := []*processor{} - for i := 0; i < nProcessors; i++ { - processors = append(processors, newProcessor(i)) - } - return &dispatcher{ - processors: processors, - seed: generateRandUint32(), + wg sync.WaitGroup +} + +func (d *dispatcher) init(nProcessors int) { + d.close() + d.wait() + d.processors = make([]processor, nProcessors) + d.seed = generateRandUint32() + for i := range d.processors { + p := &d.processors[i] + p.sleeper.AddWaker(&p.newEndpointWaker, newEndpointWaker) + p.sleeper.AddWaker(&p.closeWaker, closeWaker) + d.wg.Add(1) + // NB: sleeper-waker registration must happen synchronously to avoid races + // with `close`. It's possible to pull all this logic into `start`, but + // that results in a heap-allocated function literal. + go p.start(&d.wg) } } func (d *dispatcher) close() { - for _, p := range d.processors { - p.close() + for i := range d.processors { + d.processors[i].close() } } func (d *dispatcher) wait() { - for _, p := range d.processors { - p.wait() - } + d.wg.Wait() } func (d *dispatcher) queuePacket(r *stack.Route, stackEP stack.TransportEndpoint, id stack.TransportEndpointID, pkt *stack.PacketBuffer) { @@ -231,20 +217,18 @@ func generateRandUint32() uint32 { if _, err := rand.Read(b); err != nil { panic(err) } - return uint32(b[0]) | uint32(b[1])<<8 | uint32(b[2])<<16 | uint32(b[3])<<24 + return binary.LittleEndian.Uint32(b) } func (d *dispatcher) selectProcessor(id stack.TransportEndpointID) *processor { - payload := []byte{ - byte(id.LocalPort), - byte(id.LocalPort >> 8), - byte(id.RemotePort), - byte(id.RemotePort >> 8)} + var payload [4]byte + binary.LittleEndian.PutUint16(payload[0:], id.LocalPort) + binary.LittleEndian.PutUint16(payload[2:], id.RemotePort) h := jenkins.Sum32(d.seed) - h.Write(payload) + h.Write(payload[:]) h.Write([]byte(id.LocalAddress)) h.Write([]byte(id.RemoteAddress)) - return d.processors[h.Sum32()%uint32(len(d.processors))] + return &d.processors[h.Sum32()%uint32(len(d.processors))] } diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index 10df2bcd5..0f7487963 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -396,7 +396,8 @@ type endpoint struct { mu sync.Mutex `state:"nosave"` ownedByUser uint32 - // state must be read/set using the EndpointState()/setEndpointState() methods. + // state must be read/set using the EndpointState()/setEndpointState() + // methods. state EndpointState `state:".(EndpointState)"` // origEndpointState is only used during a restore phase to save the @@ -405,8 +406,8 @@ type endpoint struct { origEndpointState EndpointState `state:"nosave"` isPortReserved bool `state:"manual"` - isRegistered bool - boundNICID tcpip.NICID `state:"manual"` + isRegistered bool `state:"manual"` + boundNICID tcpip.NICID route stack.Route `state:"manual"` ttl uint8 v6only bool @@ -415,10 +416,14 @@ type endpoint struct { // disabling SO_BROADCAST, albeit as a NOOP. broadcast bool + // portFlags stores the current values of port related flags. + portFlags ports.Flags + // Values used to reserve a port or register a transport endpoint // (which ever happens first). boundBindToDevice tcpip.NICID boundPortFlags ports.Flags + boundDest tcpip.FullAddress // effectiveNetProtos contains the network protocols actually in use. In // most cases it will only contain "netProto", but in cases like IPv6 @@ -426,7 +431,7 @@ type endpoint struct { // protocols (e.g., IPv6 and IPv4) or a single different protocol (e.g., // IPv4 when IPv6 endpoint is bound or connected to an IPv4 mapped // address). - effectiveNetProtos []tcpip.NetworkProtocolNumber `state:"manual"` + effectiveNetProtos []tcpip.NetworkProtocolNumber // workerRunning specifies if a worker goroutine is running. workerRunning bool @@ -462,13 +467,6 @@ type endpoint struct { // sack holds TCP SACK related information for this endpoint. sack SACKInfo - // reusePort is set to true if SO_REUSEPORT is enabled. - reusePort bool - - // registeredReusePort is set if the current endpoint registration was - // done with SO_REUSEPORT enabled. - registeredReusePort bool - // bindToDevice is set to the NIC on which to bind or disabled if 0. bindToDevice tcpip.NICID @@ -488,7 +486,6 @@ type endpoint struct { // The options below aren't implemented, but we remember the user // settings because applications expect to be able to set/query these // options. - reuseAddr bool // slowAck holds the negated state of quick ack. It is stubbed out and // does nothing. @@ -838,7 +835,6 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue rcvBufSize: DefaultReceiveBufferSize, sndBufSize: DefaultSendBufferSize, sndMTU: int(math.MaxInt32), - reuseAddr: true, keepalive: keepalive{ // Linux defaults. idle: 2 * time.Hour, @@ -851,12 +847,12 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue maxSynRetries: DefaultSynRetries, } - var ss tcpip.StackSendBufferSizeOption + var ss SendBufferSizeOption if err := s.TransportProtocolOption(ProtocolNumber, &ss); err == nil { e.sndBufSize = ss.Default } - var rs tcpip.StackReceiveBufferSizeOption + var rs ReceiveBufferSizeOption if err := s.TransportProtocolOption(ProtocolNumber, &rs); err == nil { e.rcvBufSize = rs.Default } @@ -871,7 +867,7 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue e.rcvAutoParams.disabled = !bool(mrb) } - var de tcpip.StackDelayEnabled + var de DelayEnabled if err := s.TransportProtocolOption(ProtocolNumber, &de); err == nil && de { e.SetSockOptBool(tcpip.DelayOption, true) } @@ -1025,15 +1021,15 @@ func (e *endpoint) closeNoShutdownLocked() { // in Listen() when trying to register. if e.EndpointState() == StateListen && e.isPortReserved { if e.isRegistered { - e.stack.StartTransportEndpointCleanup(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, ports.Flags{LoadBalanced: e.registeredReusePort}, e.boundBindToDevice) + e.stack.StartTransportEndpointCleanup(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.boundPortFlags, e.boundBindToDevice) e.isRegistered = false - e.registeredReusePort = false } - e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.ID.LocalAddress, e.ID.LocalPort, e.boundPortFlags, e.boundBindToDevice) + e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.ID.LocalAddress, e.ID.LocalPort, e.boundPortFlags, e.boundBindToDevice, e.boundDest) e.isPortReserved = false e.boundBindToDevice = 0 e.boundPortFlags = ports.Flags{} + e.boundDest = tcpip.FullAddress{} } // Mark endpoint as closed. @@ -1091,17 +1087,17 @@ func (e *endpoint) cleanupLocked() { e.workerCleanup = false if e.isRegistered { - e.stack.StartTransportEndpointCleanup(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, ports.Flags{LoadBalanced: e.registeredReusePort}, e.boundBindToDevice) + e.stack.StartTransportEndpointCleanup(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.boundPortFlags, e.boundBindToDevice) e.isRegistered = false - e.registeredReusePort = false } if e.isPortReserved { - e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.ID.LocalAddress, e.ID.LocalPort, e.boundPortFlags, e.boundBindToDevice) + e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.ID.LocalAddress, e.ID.LocalPort, e.boundPortFlags, e.boundBindToDevice, e.boundDest) e.isPortReserved = false } e.boundBindToDevice = 0 e.boundPortFlags = ports.Flags{} + e.boundDest = tcpip.FullAddress{} e.route.Release() e.stack.CompleteTransportEndpointCleanup(e) @@ -1213,9 +1209,27 @@ func (e *endpoint) SetOwner(owner tcpip.PacketOwner) { e.owner = owner } +func (e *endpoint) takeLastError() *tcpip.Error { + e.lastErrorMu.Lock() + defer e.lastErrorMu.Unlock() + err := e.lastError + e.lastError = nil + return err +} + // Read reads data from the endpoint. func (e *endpoint) Read(*tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) { e.LockUser() + defer e.UnlockUser() + + // When in SYN-SENT state, let the caller block on the receive. + // An application can initiate a non-blocking connect and then block + // on a receive. It can expect to read any data after the handshake + // is complete. RFC793, section 3.9, p58. + if e.EndpointState() == StateSynSent { + return buffer.View{}, tcpip.ControlMessages{}, tcpip.ErrWouldBlock + } + // The endpoint can be read if it's connected, or if it's already closed // but has some pending unread data. Also note that a RST being received // would cause the state to become StateError so we should allow the @@ -1225,7 +1239,6 @@ func (e *endpoint) Read(*tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, if s := e.EndpointState(); !s.connected() && s != StateClose && bufUsed == 0 { e.rcvListMu.Unlock() he := e.HardError - e.UnlockUser() if s == StateError { return buffer.View{}, tcpip.ControlMessages{}, he } @@ -1235,7 +1248,6 @@ func (e *endpoint) Read(*tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, v, err := e.readLocked() e.rcvListMu.Unlock() - e.UnlockUser() if err == tcpip.ErrClosedForReceive { e.stats.ReadErrors.ReadClosed.Increment() @@ -1522,12 +1534,12 @@ func (e *endpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error { case tcpip.ReuseAddressOption: e.LockUser() - e.reuseAddr = v + e.portFlags.TupleOnly = v e.UnlockUser() case tcpip.ReusePortOption: e.LockUser() - e.reusePort = v + e.portFlags.LoadBalanced = v e.UnlockUser() case tcpip.V6OnlyOption: @@ -1585,10 +1597,17 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { e.UnlockUser() e.notifyProtocolGoroutine(notifyMSSChanged) + case tcpip.MTUDiscoverOption: + // Return not supported if attempting to set this option to + // anything other than path MTU discovery disabled. + if v != tcpip.PMTUDiscoveryDont { + return tcpip.ErrNotSupported + } + case tcpip.ReceiveBufferSizeOption: // Make sure the receive buffer size is within the min and max // allowed. - var rs tcpip.StackReceiveBufferSizeOption + var rs ReceiveBufferSizeOption if err := e.stack.TransportProtocolOption(ProtocolNumber, &rs); err == nil { if v < rs.Min { v = rs.Min @@ -1638,7 +1657,7 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { case tcpip.SendBufferSizeOption: // Make sure the send buffer size is within the min and max // allowed. - var ss tcpip.StackSendBufferSizeOption + var ss SendBufferSizeOption if err := e.stack.TransportProtocolOption(ProtocolNumber, &ss); err == nil { if v < ss.Min { v = ss.Min @@ -1678,7 +1697,7 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { return tcpip.ErrInvalidOptionValue } } - var rs tcpip.StackReceiveBufferSizeOption + var rs ReceiveBufferSizeOption if err := e.stack.TransportProtocolOption(ProtocolNumber, &rs); err == nil { if v < rs.Min/2 { v = rs.Min / 2 @@ -1781,6 +1800,9 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { e.deferAccept = time.Duration(v) e.UnlockUser() + case tcpip.SocketDetachFilterOption: + return nil + default: return nil } @@ -1831,14 +1853,14 @@ func (e *endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) { case tcpip.ReuseAddressOption: e.LockUser() - v := e.reuseAddr + v := e.portFlags.TupleOnly e.UnlockUser() return v, nil case tcpip.ReusePortOption: e.LockUser() - v := e.reusePort + v := e.portFlags.LoadBalanced e.UnlockUser() return v, nil @@ -1892,6 +1914,11 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { v := header.TCPDefaultMSS return v, nil + case tcpip.MTUDiscoverOption: + // Always return the path MTU discovery disabled setting since + // it's the only one supported. + return tcpip.PMTUDiscoveryDont, nil + case tcpip.ReceiveQueueSizeOption: return e.readyReceiveSize() @@ -1937,11 +1964,7 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { switch o := opt.(type) { case tcpip.ErrorOption: - e.lastErrorMu.Lock() - err := e.lastError - e.lastError = nil - e.lastErrorMu.Unlock() - return err + return e.takeLastError() case *tcpip.BindToDeviceOption: e.LockUser() @@ -2091,8 +2114,6 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc } defer r.Release() - origID := e.ID - netProtos := []tcpip.NetworkProtocolNumber{netProto} e.ID.LocalAddress = r.LocalAddress e.ID.RemoteAddress = r.RemoteAddress @@ -2100,11 +2121,10 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc if e.ID.LocalPort != 0 { // The endpoint is bound to a port, attempt to register it. - err := e.stack.RegisterTransportEndpoint(nicID, netProtos, ProtocolNumber, e.ID, e, ports.Flags{LoadBalanced: e.reusePort}, e.boundBindToDevice) + err := e.stack.RegisterTransportEndpoint(nicID, netProtos, ProtocolNumber, e.ID, e, e.boundPortFlags, e.boundBindToDevice) if err != nil { return err } - e.registeredReusePort = e.reusePort } else { // The endpoint doesn't have a local port yet, so try to get // one. Make sure that it isn't one that will result in the same @@ -2128,40 +2148,33 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc if sameAddr && p == e.ID.RemotePort { return false, nil } - // reusePort is false below because connect cannot reuse a port even if - // reusePort was set. - if !e.stack.IsPortAvailable(netProtos, ProtocolNumber, e.ID.LocalAddress, p, ports.Flags{LoadBalanced: false}, e.bindToDevice) { + if _, err := e.stack.ReservePort(netProtos, ProtocolNumber, e.ID.LocalAddress, p, e.portFlags, e.bindToDevice, addr); err != nil { return false, nil } id := e.ID id.LocalPort = p - switch e.stack.RegisterTransportEndpoint(nicID, netProtos, ProtocolNumber, id, e, ports.Flags{LoadBalanced: e.reusePort}, e.bindToDevice) { - case nil: - // Port picking successful. Save the details of - // the selected port. - e.ID = id - e.boundBindToDevice = e.bindToDevice - e.registeredReusePort = e.reusePort - return true, nil - case tcpip.ErrPortInUse: - return false, nil - default: + if err := e.stack.RegisterTransportEndpoint(nicID, netProtos, ProtocolNumber, id, e, e.portFlags, e.bindToDevice); err != nil { + e.stack.ReleasePort(netProtos, ProtocolNumber, e.ID.LocalAddress, p, e.portFlags, e.bindToDevice, addr) + if err == tcpip.ErrPortInUse { + return false, nil + } return false, err } + + // Port picking successful. Save the details of + // the selected port. + e.ID = id + e.isPortReserved = true + e.boundBindToDevice = e.bindToDevice + e.boundPortFlags = e.portFlags + e.boundDest = addr + return true, nil }); err != nil { return err } } - // Remove the port reservation. This can happen when Bind is called - // before Connect: in such a case we don't want to hold on to - // reservations anymore. - if e.isPortReserved { - e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, origID.LocalAddress, origID.LocalPort, e.boundPortFlags, e.boundBindToDevice) - e.isPortReserved = false - } - e.isRegistered = true e.setEndpointState(StateConnecting) e.route = r.Clone() @@ -2340,13 +2353,12 @@ func (e *endpoint) listen(backlog int) *tcpip.Error { } // Register the endpoint. - if err := e.stack.RegisterTransportEndpoint(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, ports.Flags{LoadBalanced: e.reusePort}, e.boundBindToDevice); err != nil { + if err := e.stack.RegisterTransportEndpoint(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.boundPortFlags, e.boundBindToDevice); err != nil { return err } e.isRegistered = true e.setEndpointState(StateListen) - e.registeredReusePort = e.reusePort // The channel may be non-nil when we're restoring the endpoint, and it // may be pre-populated with some previously accepted (but not Accepted) @@ -2433,16 +2445,13 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) (err *tcpip.Error) { } } - flags := ports.Flags{ - LoadBalanced: e.reusePort, - } - port, err := e.stack.ReservePort(netProtos, ProtocolNumber, addr.Addr, addr.Port, flags, e.bindToDevice) + port, err := e.stack.ReservePort(netProtos, ProtocolNumber, addr.Addr, addr.Port, e.portFlags, e.bindToDevice, tcpip.FullAddress{}) if err != nil { return err } e.boundBindToDevice = e.bindToDevice - e.boundPortFlags = flags + e.boundPortFlags = e.portFlags e.isPortReserved = true e.effectiveNetProtos = netProtos e.ID.LocalPort = port @@ -2450,7 +2459,7 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) (err *tcpip.Error) { // Any failures beyond this point must remove the port registration. defer func(portFlags ports.Flags, bindToDevice tcpip.NICID) { if err != nil { - e.stack.ReleasePort(netProtos, ProtocolNumber, addr.Addr, port, portFlags, bindToDevice) + e.stack.ReleasePort(netProtos, ProtocolNumber, addr.Addr, port, portFlags, bindToDevice, tcpip.FullAddress{}) e.isPortReserved = false e.effectiveNetProtos = nil e.ID.LocalPort = 0 @@ -2473,6 +2482,10 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) (err *tcpip.Error) { e.ID.LocalAddress = addr.Addr } + if err := e.stack.CheckRegisterTransportEndpoint(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e.boundPortFlags, e.boundBindToDevice); err != nil { + return err + } + // Mark endpoint as bound. e.setEndpointState(StateBound) @@ -2537,6 +2550,18 @@ func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.C e.sndBufMu.Unlock() e.notifyProtocolGoroutine(notifyMTUChanged) + + case stack.ControlNoRoute: + e.lastErrorMu.Lock() + e.lastError = tcpip.ErrNoRoute + e.lastErrorMu.Unlock() + e.notifyProtocolGoroutine(notifyError) + + case stack.ControlNetworkUnreachable: + e.lastErrorMu.Lock() + e.lastError = tcpip.ErrNetworkUnreachable + e.lastErrorMu.Unlock() + e.notifyProtocolGoroutine(notifyError) } } @@ -2609,7 +2634,7 @@ func (e *endpoint) receiveBufferSize() int { } func (e *endpoint) maxReceiveBufferSize() int { - var rs tcpip.StackReceiveBufferSizeOption + var rs ReceiveBufferSizeOption if err := e.stack.TransportProtocolOption(ProtocolNumber, &rs); err != nil { // As a fallback return the hardcoded max buffer size. return MaxBufferSize @@ -2690,7 +2715,7 @@ func timeStampOffset() uint32 { // if the SYN options indicate that the SACK option was negotiated and the TCP // stack is configured to enable TCP SACK option. func (e *endpoint) maybeEnableSACKPermitted(synOpts *header.TCPSynOptions) { - var v tcpip.StackSACKEnabled + var v SACKEnabled if err := e.stack.TransportProtocolOption(ProtocolNumber, &v); err != nil { // Stack doesn't support SACK. So just return. return diff --git a/pkg/tcpip/transport/tcp/endpoint_state.go b/pkg/tcpip/transport/tcp/endpoint_state.go index 0bebec2d1..abf1ac5c9 100644 --- a/pkg/tcpip/transport/tcp/endpoint_state.go +++ b/pkg/tcpip/transport/tcp/endpoint_state.go @@ -93,10 +93,6 @@ func (e *endpoint) beforeSave() { if e.waiterQueue != nil && !e.waiterQueue.IsEmpty() { panic("endpoint still has waiters upon save") } - - if e.EndpointState() != StateClose && !((e.EndpointState() == StateBound || e.EndpointState() == StateListen) == e.isPortReserved) { - panic("endpoints which are not in the closed state must have a reserved port IFF they are in bound or listen state") - } } // saveAcceptedChan is invoked by stateify. @@ -186,26 +182,33 @@ func (e *endpoint) Resume(s *stack.Stack) { epState := e.origEndpointState switch epState { case StateInitial, StateBound, StateListen, StateConnecting, StateEstablished: - var ss tcpip.StackSendBufferSizeOption + var ss SendBufferSizeOption if err := e.stack.TransportProtocolOption(ProtocolNumber, &ss); err == nil { if e.sndBufSize < ss.Min || e.sndBufSize > ss.Max { panic(fmt.Sprintf("endpoint.sndBufSize %d is outside the min and max allowed [%d, %d]", e.sndBufSize, ss.Min, ss.Max)) } - if e.rcvBufSize < ss.Min || e.rcvBufSize > ss.Max { - panic(fmt.Sprintf("endpoint.rcvBufSize %d is outside the min and max allowed [%d, %d]", e.rcvBufSize, ss.Min, ss.Max)) + } + + var rs ReceiveBufferSizeOption + if err := e.stack.TransportProtocolOption(ProtocolNumber, &rs); err == nil { + if e.rcvBufSize < rs.Min || e.rcvBufSize > rs.Max { + panic(fmt.Sprintf("endpoint.rcvBufSize %d is outside the min and max allowed [%d, %d]", e.rcvBufSize, rs.Min, rs.Max)) } } } bind := func() { - if len(e.BindAddr) == 0 { - e.BindAddr = e.ID.LocalAddress + addr, _, err := e.checkV4MappedLocked(tcpip.FullAddress{Addr: e.BindAddr, Port: e.ID.LocalPort}) + if err != nil { + panic("unable to parse BindAddr: " + err.String()) } - addr := e.BindAddr - port := e.ID.LocalPort - if err := e.Bind(tcpip.FullAddress{Addr: addr, Port: port}); err != nil { - panic(fmt.Sprintf("endpoint binding [%v]:%d failed: %v", addr, port, err)) + if ok := e.stack.ReserveTuple(e.effectiveNetProtos, ProtocolNumber, addr.Addr, addr.Port, e.boundPortFlags, e.boundBindToDevice, e.boundDest); !ok { + panic(fmt.Sprintf("unable to re-reserve tuple (%v, %q, %d, %+v, %d, %v)", e.effectiveNetProtos, addr.Addr, addr.Port, e.boundPortFlags, e.boundBindToDevice, e.boundDest)) } + e.isPortReserved = true + + // Mark endpoint as bound. + e.setEndpointState(StateBound) } switch { @@ -277,17 +280,7 @@ func (e *endpoint) Resume(s *stack.Stack) { tcpip.AsyncLoading.Done() }() case epState == StateClose: - if e.isPortReserved { - tcpip.AsyncLoading.Add(1) - go func() { - connectedLoading.Wait() - listenLoading.Wait() - connectingLoading.Wait() - bind() - e.setEndpointState(StateClose) - tcpip.AsyncLoading.Done() - }() - } + e.isPortReserved = false e.state = StateClose e.stack.CompleteTransportEndpointCleanup(e) tcpip.DeleteDanglingEndpoint(e) diff --git a/pkg/tcpip/transport/tcp/protocol.go b/pkg/tcpip/transport/tcp/protocol.go index 3cff55afa..b34e47bbd 100644 --- a/pkg/tcpip/transport/tcp/protocol.go +++ b/pkg/tcpip/transport/tcp/protocol.go @@ -76,6 +76,31 @@ const ( ccCubic = "cubic" ) +// SACKEnabled is used by stack.(*Stack).TransportProtocolOption to +// enable/disable SACK support in TCP. See: https://tools.ietf.org/html/rfc2018. +type SACKEnabled bool + +// DelayEnabled is used by stack.(Stack*).TransportProtocolOption to +// enable/disable Nagle's algorithm in TCP. +type DelayEnabled bool + +// SendBufferSizeOption is used by stack.(Stack*).TransportProtocolOption +// to get/set the default, min and max TCP send buffer sizes. +type SendBufferSizeOption struct { + Min int + Default int + Max int +} + +// ReceiveBufferSizeOption is used by +// stack.(Stack*).TransportProtocolOption to get/set the default, min and max +// TCP receive buffer sizes. +type ReceiveBufferSizeOption struct { + Min int + Default int + Max int +} + // syncRcvdCounter tracks the number of endpoints in the SYN-RCVD state. The // value is protected by a mutex so that we can increment only when it's // guaranteed not to go above a threshold. @@ -137,8 +162,8 @@ type protocol struct { mu sync.RWMutex sackEnabled bool delayEnabled bool - sendBufferSize tcpip.StackSendBufferSizeOption - recvBufferSize tcpip.StackReceiveBufferSizeOption + sendBufferSize SendBufferSizeOption + recvBufferSize ReceiveBufferSizeOption congestionControl string availableCongestionControl []string moderateReceiveBuffer bool @@ -149,7 +174,7 @@ type protocol struct { maxRetries uint32 synRcvdCount synRcvdCounter synRetries uint8 - dispatcher *dispatcher + dispatcher dispatcher } // Number returns the tcp protocol number. @@ -249,19 +274,19 @@ func replyWithReset(s *segment, tos, ttl uint8) { // SetOption implements stack.TransportProtocol.SetOption. func (p *protocol) SetOption(option interface{}) *tcpip.Error { switch v := option.(type) { - case tcpip.StackSACKEnabled: + case SACKEnabled: p.mu.Lock() p.sackEnabled = bool(v) p.mu.Unlock() return nil - case tcpip.StackDelayEnabled: + case DelayEnabled: p.mu.Lock() p.delayEnabled = bool(v) p.mu.Unlock() return nil - case tcpip.StackSendBufferSizeOption: + case SendBufferSizeOption: if v.Min <= 0 || v.Default < v.Min || v.Default > v.Max { return tcpip.ErrInvalidOptionValue } @@ -270,7 +295,7 @@ func (p *protocol) SetOption(option interface{}) *tcpip.Error { p.mu.Unlock() return nil - case tcpip.StackReceiveBufferSizeOption: + case ReceiveBufferSizeOption: if v.Min <= 0 || v.Default < v.Min || v.Default > v.Max { return tcpip.ErrInvalidOptionValue } @@ -363,25 +388,25 @@ func (p *protocol) SetOption(option interface{}) *tcpip.Error { // Option implements stack.TransportProtocol.Option. func (p *protocol) Option(option interface{}) *tcpip.Error { switch v := option.(type) { - case *tcpip.StackSACKEnabled: + case *SACKEnabled: p.mu.RLock() - *v = tcpip.StackSACKEnabled(p.sackEnabled) + *v = SACKEnabled(p.sackEnabled) p.mu.RUnlock() return nil - case *tcpip.StackDelayEnabled: + case *DelayEnabled: p.mu.RLock() - *v = tcpip.StackDelayEnabled(p.delayEnabled) + *v = DelayEnabled(p.delayEnabled) p.mu.RUnlock() return nil - case *tcpip.StackSendBufferSizeOption: + case *SendBufferSizeOption: p.mu.RLock() *v = p.sendBufferSize p.mu.RUnlock() return nil - case *tcpip.StackReceiveBufferSizeOption: + case *ReceiveBufferSizeOption: p.mu.RLock() *v = p.recvBufferSize p.mu.RUnlock() @@ -490,13 +515,13 @@ func (*protocol) Parse(pkt *stack.PacketBuffer) bool { // NewProtocol returns a TCP transport protocol. func NewProtocol() stack.TransportProtocol { - return &protocol{ - sendBufferSize: tcpip.StackSendBufferSizeOption{ + p := protocol{ + sendBufferSize: SendBufferSizeOption{ Min: MinBufferSize, Default: DefaultSendBufferSize, Max: MaxBufferSize, }, - recvBufferSize: tcpip.StackReceiveBufferSizeOption{ + recvBufferSize: ReceiveBufferSizeOption{ Min: MinBufferSize, Default: DefaultReceiveBufferSize, Max: MaxBufferSize, @@ -506,10 +531,11 @@ func NewProtocol() stack.TransportProtocol { tcpLingerTimeout: DefaultTCPLingerTimeout, tcpTimeWaitTimeout: DefaultTCPTimeWaitTimeout, synRcvdCount: synRcvdCounter{threshold: SynRcvdCountThreshold}, - dispatcher: newDispatcher(runtime.GOMAXPROCS(0)), synRetries: DefaultSynRetries, minRTO: MinRTO, maxRTO: MaxRTO, maxRetries: MaxRetries, } + p.dispatcher.init(runtime.GOMAXPROCS(0)) + return &p } diff --git a/pkg/tcpip/transport/tcp/rcv.go b/pkg/tcpip/transport/tcp/rcv.go index dd89a292a..5e0bfe585 100644 --- a/pkg/tcpip/transport/tcp/rcv.go +++ b/pkg/tcpip/transport/tcp/rcv.go @@ -372,7 +372,7 @@ func (r *receiver) handleRcvdSegment(s *segment) (drop bool, err *tcpip.Error) { // We only store the segment if it's within our buffer // size limit. if r.pendingBufUsed < r.pendingBufSize { - r.pendingBufUsed += s.logicalLen() + r.pendingBufUsed += seqnum.Size(s.segMemSize()) s.incRef() heap.Push(&r.pendingRcvdSegments, s) UpdateSACKBlocks(&r.ep.sack, segSeq, segSeq.Add(segLen), r.rcvNxt) @@ -406,7 +406,7 @@ func (r *receiver) handleRcvdSegment(s *segment) (drop bool, err *tcpip.Error) { } heap.Pop(&r.pendingRcvdSegments) - r.pendingBufUsed -= s.logicalLen() + r.pendingBufUsed -= seqnum.Size(s.segMemSize()) s.decRef() } return false, nil diff --git a/pkg/tcpip/transport/tcp/segment.go b/pkg/tcpip/transport/tcp/segment.go index 0280892a8..bb60dc29d 100644 --- a/pkg/tcpip/transport/tcp/segment.go +++ b/pkg/tcpip/transport/tcp/segment.go @@ -138,6 +138,12 @@ func (s *segment) logicalLen() seqnum.Size { return l } +// segMemSize is the amount of memory used to hold the segment data and +// the associated metadata. +func (s *segment) segMemSize() int { + return segSize + s.data.Size() +} + // parse populates the sequence & ack numbers, flags, and window fields of the // segment from the TCP header stored in the data. It then updates the view to // skip the header. diff --git a/pkg/tcpip/transport/tcp/segment_unsafe.go b/pkg/tcpip/transport/tcp/segment_unsafe.go new file mode 100644 index 000000000..0ab7b8f56 --- /dev/null +++ b/pkg/tcpip/transport/tcp/segment_unsafe.go @@ -0,0 +1,23 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tcp + +import ( + "unsafe" +) + +const ( + segSize = int(unsafe.Sizeof(segment{})) +) diff --git a/pkg/tcpip/transport/tcp/snd.go b/pkg/tcpip/transport/tcp/snd.go index acacb42e4..5862c32f2 100644 --- a/pkg/tcpip/transport/tcp/snd.go +++ b/pkg/tcpip/transport/tcp/snd.go @@ -833,25 +833,6 @@ func (s *sender) maybeSendSegment(seg *segment, limit int, end seqnum.Value) (se panic("Netstack queues FIN segments without data.") } - segEnd = seg.sequenceNumber.Add(seqnum.Size(seg.data.Size())) - // If the entire segment cannot be accomodated in the receiver - // advertized window, skip splitting and sending of the segment. - // ref: net/ipv4/tcp_output.c::tcp_snd_wnd_test() - // - // Linux checks this for all segment transmits not triggered - // by a probe timer. On this condition, it defers the segment - // split and transmit to a short probe timer. - // ref: include/net/tcp.h::tcp_check_probe_timer() - // ref: net/ipv4/tcp_output.c::tcp_write_wakeup() - // - // Instead of defining a new transmit timer, we attempt to split the - // segment right here if there are no pending segments. - // If there are pending segments, segment transmits are deferred - // to the retransmit timer handler. - if s.sndUna != s.sndNxt && !segEnd.LessThan(end) { - return false - } - if !seg.sequenceNumber.LessThan(end) { return false } @@ -861,14 +842,48 @@ func (s *sender) maybeSendSegment(seg *segment, limit int, end seqnum.Value) (se return false } - // The segment size limit is computed as a function of sender congestion - // window and MSS. When sender congestion window is > 1, this limit can - // be larger than MSS. Ensure that the currently available send space - // is not greater than minimum of this limit and MSS. + // If the whole segment or at least 1MSS sized segment cannot + // be accomodated in the receiver advertized window, skip + // splitting and sending of the segment. ref: + // net/ipv4/tcp_output.c::tcp_snd_wnd_test() + // + // Linux checks this for all segment transmits not triggered by + // a probe timer. On this condition, it defers the segment split + // and transmit to a short probe timer. + // + // ref: include/net/tcp.h::tcp_check_probe_timer() + // ref: net/ipv4/tcp_output.c::tcp_write_wakeup() + // + // Instead of defining a new transmit timer, we attempt to split + // the segment right here if there are no pending segments. If + // there are pending segments, segment transmits are deferred to + // the retransmit timer handler. + if s.sndUna != s.sndNxt { + switch { + case available >= seg.data.Size(): + // OK to send, the whole segments fits in the + // receiver's advertised window. + case available >= s.maxPayloadSize: + // OK to send, at least 1 MSS sized segment fits + // in the receiver's advertised window. + default: + return false + } + } + + // The segment size limit is computed as a function of sender + // congestion window and MSS. When sender congestion window is > + // 1, this limit can be larger than MSS. Ensure that the + // currently available send space is not greater than minimum of + // this limit and MSS. if available > limit { available = limit } - if available > s.maxPayloadSize { + + // If GSO is not in use then cap available to + // maxPayloadSize. When GSO is in use the gVisor GSO logic or + // the host GSO logic will cap the segment to the correct size. + if s.ep.gso == nil && available > s.maxPayloadSize { available = s.maxPayloadSize } diff --git a/pkg/tcpip/transport/tcp/tcp_sack_test.go b/pkg/tcpip/transport/tcp/tcp_sack_test.go index 812e503bc..99521f0c1 100644 --- a/pkg/tcpip/transport/tcp/tcp_sack_test.go +++ b/pkg/tcpip/transport/tcp/tcp_sack_test.go @@ -46,8 +46,8 @@ func createConnectedWithSACKAndTS(c *context.Context) *context.RawEndpoint { func setStackSACKPermitted(t *testing.T, c *context.Context, enable bool) { t.Helper() - if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.StackSACKEnabled(enable)); err != nil { - t.Fatalf("c.s.SetTransportProtocolOption(tcp.ProtocolNumber, StackSACKEnabled(%t) = %s", enable, err) + if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcp.SACKEnabled(enable)); err != nil { + t.Fatalf("c.s.SetTransportProtocolOption(tcp.ProtocolNumber, SACKEnabled(%t) = %s", enable, err) } } diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go index aca6a7951..e67ec42b1 100644 --- a/pkg/tcpip/transport/tcp/tcp_test.go +++ b/pkg/tcpip/transport/tcp/tcp_test.go @@ -3095,6 +3095,63 @@ func TestMaxRTO(t *testing.T) { } } +// TestRetransmitIPv4IDUniqueness tests that the IPv4 Identification field is +// unique on retransmits. +func TestRetransmitIPv4IDUniqueness(t *testing.T) { + for _, tc := range []struct { + name string + size int + }{ + {"1Byte", 1}, + {"512Bytes", 512}, + } { + t.Run(tc.name, func(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + c.CreateConnected(789 /* iss */, 30000 /* rcvWnd */, -1 /* epRcvBuf */) + + // Disabling PMTU discovery causes all packets sent from this socket to + // have DF=0. This needs to be done because the IPv4 ID uniqueness + // applies only to non-atomic IPv4 datagrams as defined in RFC 6864 + // Section 4, and datagrams with DF=0 are non-atomic. + if err := c.EP.SetSockOptInt(tcpip.MTUDiscoverOption, tcpip.PMTUDiscoveryDont); err != nil { + t.Fatalf("disabling PMTU discovery via sockopt to force DF=0 failed: %s", err) + } + + if _, _, err := c.EP.Write(tcpip.SlicePayload(buffer.NewView(tc.size)), tcpip.WriteOptions{}); err != nil { + t.Fatalf("Write failed: %s", err) + } + pkt := c.GetPacket() + checker.IPv4(t, pkt, + checker.FragmentFlags(0), + checker.TCP( + checker.DstPort(context.TestPort), + checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), + ), + ) + idSet := map[uint16]struct{}{header.IPv4(pkt).ID(): struct{}{}} + // Expect two retransmitted packets, and that all packets received have + // unique IPv4 ID values. + for i := 0; i <= 2; i++ { + pkt := c.GetPacket() + checker.IPv4(t, pkt, + checker.FragmentFlags(0), + checker.TCP( + checker.DstPort(context.TestPort), + checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), + ), + ) + id := header.IPv4(pkt).ID() + if _, exists := idSet[id]; exists { + t.Fatalf("duplicate IPv4 ID=%d found in retransmitted packet", id) + } + idSet[id] = struct{}{} + } + }) + } +} + func TestFinImmediately(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() @@ -3879,6 +3936,9 @@ func TestReusePort(t *testing.T) { if err != nil { t.Fatalf("NewEndpoint failed; %s", err) } + if err := c.EP.SetSockOptBool(tcpip.ReuseAddressOption, true); err != nil { + t.Fatalf("SetSockOptBool ReuseAddressOption failed: %s", err) + } if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { t.Fatalf("Bind failed: %s", err) } @@ -3888,6 +3948,9 @@ func TestReusePort(t *testing.T) { if err != nil { t.Fatalf("NewEndpoint failed; %s", err) } + if err := c.EP.SetSockOptBool(tcpip.ReuseAddressOption, true); err != nil { + t.Fatalf("SetSockOptBool ReuseAddressOption failed: %s", err) + } if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { t.Fatalf("Bind failed: %s", err) } @@ -3898,6 +3961,9 @@ func TestReusePort(t *testing.T) { if err != nil { t.Fatalf("NewEndpoint failed; %s", err) } + if err := c.EP.SetSockOptBool(tcpip.ReuseAddressOption, true); err != nil { + t.Fatalf("SetSockOptBool ReuseAddressOption failed: %s", err) + } if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { t.Fatalf("Bind failed: %s", err) } @@ -3910,6 +3976,9 @@ func TestReusePort(t *testing.T) { if err != nil { t.Fatalf("NewEndpoint failed; %s", err) } + if err := c.EP.SetSockOptBool(tcpip.ReuseAddressOption, true); err != nil { + t.Fatalf("SetSockOptBool ReuseAddressOption failed: %s", err) + } if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { t.Fatalf("Bind failed: %s", err) } @@ -3920,6 +3989,9 @@ func TestReusePort(t *testing.T) { if err != nil { t.Fatalf("NewEndpoint failed; %s", err) } + if err := c.EP.SetSockOptBool(tcpip.ReuseAddressOption, true); err != nil { + t.Fatalf("SetSockOptBool ReuseAddressOption failed: %s", err) + } if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { t.Fatalf("Bind failed: %s", err) } @@ -3932,6 +4004,9 @@ func TestReusePort(t *testing.T) { if err != nil { t.Fatalf("NewEndpoint failed; %s", err) } + if err := c.EP.SetSockOptBool(tcpip.ReuseAddressOption, true); err != nil { + t.Fatalf("SetSockOptBool ReuseAddressOption failed: %s", err) + } if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { t.Fatalf("Bind failed: %s", err) } @@ -3987,7 +4062,7 @@ func TestDefaultBufferSizes(t *testing.T) { checkRecvBufferSize(t, ep, tcp.DefaultReceiveBufferSize) // Change the default send buffer size. - if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.StackSendBufferSizeOption{ + if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.SendBufferSizeOption{ Min: 1, Default: tcp.DefaultSendBufferSize * 2, Max: tcp.DefaultSendBufferSize * 20}); err != nil { @@ -4004,7 +4079,7 @@ func TestDefaultBufferSizes(t *testing.T) { checkRecvBufferSize(t, ep, tcp.DefaultReceiveBufferSize) // Change the default receive buffer size. - if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.StackReceiveBufferSizeOption{ + if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.ReceiveBufferSizeOption{ Min: 1, Default: tcp.DefaultReceiveBufferSize * 3, Max: tcp.DefaultReceiveBufferSize * 30}); err != nil { @@ -4035,11 +4110,11 @@ func TestMinMaxBufferSizes(t *testing.T) { defer ep.Close() // Change the min/max values for send/receive - if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.StackReceiveBufferSizeOption{Min: 200, Default: tcp.DefaultReceiveBufferSize * 2, Max: tcp.DefaultReceiveBufferSize * 20}); err != nil { + if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.ReceiveBufferSizeOption{Min: 200, Default: tcp.DefaultReceiveBufferSize * 2, Max: tcp.DefaultReceiveBufferSize * 20}); err != nil { t.Fatalf("SetTransportProtocolOption failed: %s", err) } - if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.StackSendBufferSizeOption{Min: 300, Default: tcp.DefaultSendBufferSize * 3, Max: tcp.DefaultSendBufferSize * 30}); err != nil { + if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.SendBufferSizeOption{Min: 300, Default: tcp.DefaultSendBufferSize * 3, Max: tcp.DefaultSendBufferSize * 30}); err != nil { t.Fatalf("SetTransportProtocolOption failed: %s", err) } @@ -5678,7 +5753,7 @@ func TestReceiveBufferAutoTuningApplicationLimited(t *testing.T) { // the segment queue holding unprocessed packets is limited to 500. const receiveBufferSize = 80 << 10 // 80KB. const maxReceiveBufferSize = receiveBufferSize * 10 - if err := stk.SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.StackReceiveBufferSizeOption{Min: 1, Default: receiveBufferSize, Max: maxReceiveBufferSize}); err != nil { + if err := stk.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.ReceiveBufferSizeOption{Min: 1, Default: receiveBufferSize, Max: maxReceiveBufferSize}); err != nil { t.Fatalf("SetTransportProtocolOption failed: %s", err) } @@ -5799,7 +5874,7 @@ func TestReceiveBufferAutoTuning(t *testing.T) { // the segment queue holding unprocessed packets is limited to 300. const receiveBufferSize = 80 << 10 // 80KB. const maxReceiveBufferSize = receiveBufferSize * 10 - if err := stk.SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.StackReceiveBufferSizeOption{Min: 1, Default: receiveBufferSize, Max: maxReceiveBufferSize}); err != nil { + if err := stk.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.ReceiveBufferSizeOption{Min: 1, Default: receiveBufferSize, Max: maxReceiveBufferSize}); err != nil { t.Fatalf("SetTransportProtocolOption failed: %s", err) } @@ -5941,7 +6016,7 @@ func TestDelayEnabled(t *testing.T) { checkDelayOption(t, c, false, false) // Delay is disabled by default. for _, v := range []struct { - delayEnabled tcpip.StackDelayEnabled + delayEnabled tcp.DelayEnabled wantDelayOption bool }{ {delayEnabled: false, wantDelayOption: false}, @@ -5956,10 +6031,10 @@ func TestDelayEnabled(t *testing.T) { } } -func checkDelayOption(t *testing.T, c *context.Context, wantDelayEnabled tcpip.StackDelayEnabled, wantDelayOption bool) { +func checkDelayOption(t *testing.T, c *context.Context, wantDelayEnabled tcp.DelayEnabled, wantDelayOption bool) { t.Helper() - var gotDelayEnabled tcpip.StackDelayEnabled + var gotDelayEnabled tcp.DelayEnabled if err := c.Stack().TransportProtocolOption(tcp.ProtocolNumber, &gotDelayEnabled); err != nil { t.Fatalf("TransportProtocolOption(tcp, &gotDelayEnabled) failed: %s", err) } diff --git a/pkg/tcpip/transport/tcp/testing/context/context.go b/pkg/tcpip/transport/tcp/testing/context/context.go index 9e262c272..37e7767d6 100644 --- a/pkg/tcpip/transport/tcp/testing/context/context.go +++ b/pkg/tcpip/transport/tcp/testing/context/context.go @@ -143,12 +143,14 @@ func New(t *testing.T, mtu uint32) *Context { TransportProtocols: []stack.TransportProtocol{tcp.NewProtocol()}, }) + const sendBufferSize = 1 << 20 // 1 MiB + const recvBufferSize = 1 << 20 // 1 MiB // Allow minimum send/receive buffer sizes to be 1 during tests. - if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.StackSendBufferSizeOption{Min: 1, Default: tcp.DefaultSendBufferSize, Max: 10 * tcp.DefaultSendBufferSize}); err != nil { + if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.SendBufferSizeOption{Min: 1, Default: sendBufferSize, Max: 10 * sendBufferSize}); err != nil { t.Fatalf("SetTransportProtocolOption failed: %s", err) } - if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.StackReceiveBufferSizeOption{Min: 1, Default: tcp.DefaultReceiveBufferSize, Max: 10 * tcp.DefaultReceiveBufferSize}); err != nil { + if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.ReceiveBufferSizeOption{Min: 1, Default: recvBufferSize, Max: 10 * recvBufferSize}); err != nil { t.Fatalf("SetTransportProtocolOption failed: %s", err) } @@ -202,7 +204,7 @@ func New(t *testing.T, mtu uint32) *Context { t: t, s: s, linkEP: ep, - WindowScale: uint8(tcp.FindWndScale(tcp.DefaultReceiveBufferSize)), + WindowScale: uint8(tcp.FindWndScale(recvBufferSize)), } } @@ -1091,7 +1093,7 @@ func (c *Context) PassiveConnectWithOptions(maxPayload, wndScale int, synOptions // SACKEnabled returns true if the TCP Protocol option SACKEnabled is set to true // for the Stack in the context. func (c *Context) SACKEnabled() bool { - var v tcpip.StackSACKEnabled + var v tcp.SACKEnabled if err := c.Stack().TransportProtocolOption(tcp.ProtocolNumber, &v); err != nil { // Stack doesn't support SACK. So just return. return false diff --git a/pkg/tcpip/transport/tcpconntrack/tcp_conntrack.go b/pkg/tcpip/transport/tcpconntrack/tcp_conntrack.go index 12bc1b5b5..558b06df0 100644 --- a/pkg/tcpip/transport/tcpconntrack/tcp_conntrack.go +++ b/pkg/tcpip/transport/tcpconntrack/tcp_conntrack.go @@ -106,6 +106,11 @@ func (t *TCB) UpdateStateOutbound(tcp header.TCP) Result { return st } +// State returns the current state of the TCB. +func (t *TCB) State() Result { + return t.state +} + // IsAlive returns true as long as the connection is established(Alive) // or connecting state. func (t *TCB) IsAlive() bool { diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index f51988047..b7d735889 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -109,6 +109,7 @@ type endpoint struct { portFlags ports.Flags bindToDevice tcpip.NICID broadcast bool + noChecksum bool lastErrorMu sync.Mutex `state:"nosave"` lastError *tcpip.Error `state:".(string)"` @@ -190,13 +191,13 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue } // Override with stack defaults. - var ss tcpip.StackSendBufferSizeOption - if err := s.TransportProtocolOption(ProtocolNumber, &ss); err == nil { + var ss stack.SendBufferSizeOption + if err := s.Option(&ss); err == nil { e.sndBufSizeMax = ss.Default } - var rs tcpip.StackReceiveBufferSizeOption - if err := s.TransportProtocolOption(ProtocolNumber, &rs); err == nil { + var rs stack.ReceiveBufferSizeOption + if err := s.Option(&rs); err == nil { e.rcvBufSizeMax = rs.Default } @@ -231,7 +232,7 @@ func (e *endpoint) Close() { switch e.state { case StateBound, StateConnected: e.stack.UnregisterTransportEndpoint(e.RegisterNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.boundPortFlags, e.boundBindToDevice) - e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.ID.LocalAddress, e.ID.LocalPort, e.boundPortFlags, e.boundBindToDevice) + e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.ID.LocalAddress, e.ID.LocalPort, e.boundPortFlags, e.boundBindToDevice, tcpip.FullAddress{}) e.boundBindToDevice = 0 e.boundPortFlags = ports.Flags{} } @@ -482,10 +483,6 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c nicID = e.BindNICID } - if to.Addr == header.IPv4Broadcast && !e.broadcast { - return 0, nil, tcpip.ErrBroadcastDisabled - } - dst, netProto, err := e.checkV4MappedLocked(*to) if err != nil { return 0, nil, err @@ -502,6 +499,10 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c resolve = route.Resolve } + if !e.broadcast && route.IsBroadcast() { + return 0, nil, tcpip.ErrBroadcastDisabled + } + if route.IsResolutionRequired() { if ch, err := resolve(nil); err != nil { if err == tcpip.ErrWouldBlock { @@ -529,7 +530,7 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c useDefaultTTL = false } - if err := sendUDP(route, buffer.View(v).ToVectorisedView(), e.ID.LocalPort, dstPort, ttl, useDefaultTTL, e.sendTOS, e.owner); err != nil { + if err := sendUDP(route, buffer.View(v).ToVectorisedView(), e.ID.LocalPort, dstPort, ttl, useDefaultTTL, e.sendTOS, e.owner, e.noChecksum); err != nil { return 0, nil, err } return int64(len(v)), nil, nil @@ -553,6 +554,11 @@ func (e *endpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error { e.multicastLoop = v e.mu.Unlock() + case tcpip.NoChecksumOption: + e.mu.Lock() + e.noChecksum = v + e.mu.Unlock() + case tcpip.ReceiveTOSOption: e.mu.Lock() e.receiveTOS = v @@ -606,6 +612,13 @@ func (e *endpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error { // SetSockOptInt implements tcpip.Endpoint.SetSockOptInt. func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { switch opt { + case tcpip.MTUDiscoverOption: + // Return not supported if the value is not disabling path + // MTU discovery. + if v != tcpip.PMTUDiscoveryDont { + return tcpip.ErrNotSupported + } + case tcpip.MulticastTTLOption: e.mu.Lock() e.multicastTTL = uint8(v) @@ -629,9 +642,9 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { case tcpip.ReceiveBufferSizeOption: // Make sure the receive buffer size is within the min and max // allowed. - var rs tcpip.StackReceiveBufferSizeOption - if err := e.stack.TransportProtocolOption(ProtocolNumber, &rs); err != nil { - panic(fmt.Sprintf("e.stack.TransportProtocolOption(%d, %+v) = %s", ProtocolNumber, rs, err)) + var rs stack.ReceiveBufferSizeOption + if err := e.stack.Option(&rs); err != nil { + panic(fmt.Sprintf("e.stack.Option(%#v) = %s", rs, err)) } if v < rs.Min { @@ -648,9 +661,9 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { case tcpip.SendBufferSizeOption: // Make sure the send buffer size is within the min and max // allowed. - var ss tcpip.StackSendBufferSizeOption - if err := e.stack.TransportProtocolOption(ProtocolNumber, &ss); err != nil { - panic(fmt.Sprintf("e.stack.TransportProtocolOption(%d, %+v) = %s", ProtocolNumber, ss, err)) + var ss stack.SendBufferSizeOption + if err := e.stack.Option(&ss); err != nil { + panic(fmt.Sprintf("e.stack.Option(%#v) = %s", ss, err)) } if v < ss.Min { @@ -803,6 +816,9 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { e.mu.Lock() e.bindToDevice = id e.mu.Unlock() + + case tcpip.SocketDetachFilterOption: + return nil } return nil } @@ -825,6 +841,12 @@ func (e *endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) { e.mu.RUnlock() return v, nil + case tcpip.NoChecksumOption: + e.mu.RLock() + v := e.noChecksum + e.mu.RUnlock() + return v, nil + case tcpip.ReceiveTOSOption: e.mu.RLock() v := e.receiveTOS @@ -894,6 +916,10 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { e.mu.RUnlock() return v, nil + case tcpip.MTUDiscoverOption: + // The only supported setting is path MTU discovery disabled. + return tcpip.PMTUDiscoveryDont, nil + case tcpip.MulticastTTLOption: e.mu.Lock() v := int(e.multicastTTL) @@ -959,7 +985,7 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { // sendUDP sends a UDP segment via the provided network endpoint and under the // provided identity. -func sendUDP(r *stack.Route, data buffer.VectorisedView, localPort, remotePort uint16, ttl uint8, useDefaultTTL bool, tos uint8, owner tcpip.PacketOwner) *tcpip.Error { +func sendUDP(r *stack.Route, data buffer.VectorisedView, localPort, remotePort uint16, ttl uint8, useDefaultTTL bool, tos uint8, owner tcpip.PacketOwner, noChecksum bool) *tcpip.Error { // Allocate a buffer for the UDP header. hdr := buffer.NewPrependable(header.UDPMinimumSize + int(r.MaxHeaderLength())) @@ -973,8 +999,12 @@ func sendUDP(r *stack.Route, data buffer.VectorisedView, localPort, remotePort u Length: length, }) - // Only calculate the checksum if offloading isn't supported. - if r.Capabilities()&stack.CapabilityTXChecksumOffload == 0 { + // Set the checksum field unless TX checksum offload is enabled. + // On IPv4, UDP checksum is optional, and a zero value indicates the + // transmitter skipped the checksum generation (RFC768). + // On IPv6, UDP checksum is not optional (RFC2460 Section 8.1). + if r.Capabilities()&stack.CapabilityTXChecksumOffload == 0 && + (!noChecksum || r.NetProto == header.IPv6ProtocolNumber) { xsum := r.PseudoHeaderChecksum(ProtocolNumber, length) for _, v := range data.Views() { xsum = header.Checksum(v, xsum) @@ -1047,7 +1077,7 @@ func (e *endpoint) Disconnect() *tcpip.Error { } else { if e.ID.LocalPort != 0 { // Release the ephemeral port. - e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.ID.LocalAddress, e.ID.LocalPort, boundPortFlags, e.boundBindToDevice) + e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.ID.LocalAddress, e.ID.LocalPort, boundPortFlags, e.boundBindToDevice, tcpip.FullAddress{}) e.boundPortFlags = ports.Flags{} } e.state = StateInitial @@ -1198,7 +1228,7 @@ func (*endpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) { func (e *endpoint) registerWithStack(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, id stack.TransportEndpointID) (stack.TransportEndpointID, tcpip.NICID, *tcpip.Error) { if e.ID.LocalPort == 0 { - port, err := e.stack.ReservePort(netProtos, ProtocolNumber, id.LocalAddress, id.LocalPort, e.portFlags, e.bindToDevice) + port, err := e.stack.ReservePort(netProtos, ProtocolNumber, id.LocalAddress, id.LocalPort, e.portFlags, e.bindToDevice, tcpip.FullAddress{}) if err != nil { return id, e.bindToDevice, err } @@ -1208,7 +1238,7 @@ func (e *endpoint) registerWithStack(nicID tcpip.NICID, netProtos []tcpip.Networ err := e.stack.RegisterTransportEndpoint(nicID, netProtos, ProtocolNumber, id, e, e.boundPortFlags, e.bindToDevice) if err != nil { - e.stack.ReleasePort(netProtos, ProtocolNumber, id.LocalAddress, id.LocalPort, e.boundPortFlags, e.bindToDevice) + e.stack.ReleasePort(netProtos, ProtocolNumber, id.LocalAddress, id.LocalPort, e.boundPortFlags, e.bindToDevice, tcpip.FullAddress{}) e.boundPortFlags = ports.Flags{} } return id, e.bindToDevice, err @@ -1350,10 +1380,37 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pk return } - e.rcvMu.Lock() + // Never receive from a multicast address. + if header.IsV4MulticastAddress(id.RemoteAddress) || + header.IsV6MulticastAddress(id.RemoteAddress) { + e.stack.Stats().UDP.InvalidSourceAddress.Increment() + e.stack.Stats().IP.InvalidSourceAddressesReceived.Increment() + e.stats.ReceiveErrors.MalformedPacketsReceived.Increment() + return + } + + // Verify checksum unless RX checksum offload is enabled. + // On IPv4, UDP checksum is optional, and a zero value means + // the transmitter omitted the checksum generation (RFC768). + // On IPv6, UDP checksum is not optional (RFC2460 Section 8.1). + if r.Capabilities()&stack.CapabilityRXChecksumOffload == 0 && + (hdr.Checksum() != 0 || r.NetProto == header.IPv6ProtocolNumber) { + xsum := r.PseudoHeaderChecksum(ProtocolNumber, hdr.Length()) + for _, v := range pkt.Data.Views() { + xsum = header.Checksum(v, xsum) + } + if hdr.CalculateChecksum(xsum) != 0xffff { + // Checksum Error. + e.stack.Stats().UDP.ChecksumErrors.Increment() + e.stats.ReceiveErrors.ChecksumErrors.Increment() + return + } + } + e.stack.Stats().UDP.PacketsReceived.Increment() e.stats.PacketsReceived.Increment() + e.rcvMu.Lock() // Drop the packet if our buffer is currently full. if !e.rcvReady || e.rcvClosed { e.rcvMu.Unlock() @@ -1394,7 +1451,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pk packet.tos, _ = header.IPv6(pkt.NetworkHeader).TOS() } - packet.timestamp = e.stack.NowNanoseconds() + packet.timestamp = e.stack.Clock().NowNanoseconds() e.rcvMu.Unlock() diff --git a/pkg/tcpip/transport/udp/protocol.go b/pkg/tcpip/transport/udp/protocol.go index fc93f93c0..0e7464e3a 100644 --- a/pkg/tcpip/transport/udp/protocol.go +++ b/pkg/tcpip/transport/udp/protocol.go @@ -21,7 +21,6 @@ package udp import ( - "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" @@ -50,9 +49,6 @@ const ( ) type protocol struct { - mu sync.RWMutex - sendBufferSize tcpip.StackSendBufferSizeOption - recvBufferSize tcpip.StackReceiveBufferSizeOption } // Number returns the udp protocol number. @@ -203,48 +199,12 @@ func (p *protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.Trans // SetOption implements stack.TransportProtocol.SetOption. func (p *protocol) SetOption(option interface{}) *tcpip.Error { - switch v := option.(type) { - case tcpip.StackSendBufferSizeOption: - if v.Min <= 0 || v.Default < v.Min || v.Default > v.Max { - return tcpip.ErrInvalidOptionValue - } - p.mu.Lock() - p.sendBufferSize = v - p.mu.Unlock() - return nil - - case tcpip.StackReceiveBufferSizeOption: - if v.Min <= 0 || v.Default < v.Min || v.Default > v.Max { - return tcpip.ErrInvalidOptionValue - } - p.mu.Lock() - p.recvBufferSize = v - p.mu.Unlock() - return nil - - default: - return tcpip.ErrUnknownProtocolOption - } + return tcpip.ErrUnknownProtocolOption } // Option implements stack.TransportProtocol.Option. func (p *protocol) Option(option interface{}) *tcpip.Error { - switch v := option.(type) { - case *tcpip.StackSendBufferSizeOption: - p.mu.RLock() - *v = p.sendBufferSize - p.mu.RUnlock() - return nil - - case *tcpip.StackReceiveBufferSizeOption: - p.mu.RLock() - *v = p.recvBufferSize - p.mu.RUnlock() - return nil - - default: - return tcpip.ErrUnknownProtocolOption - } + return tcpip.ErrUnknownProtocolOption } // Close implements stack.TransportProtocol.Close. @@ -267,8 +227,5 @@ func (*protocol) Parse(pkt *stack.PacketBuffer) bool { // NewProtocol returns a UDP transport protocol. func NewProtocol() stack.TransportProtocol { - return &protocol{ - sendBufferSize: tcpip.StackSendBufferSizeOption{Min: MinBufferSize, Default: DefaultSendBufferSize, Max: MaxBufferSize}, - recvBufferSize: tcpip.StackReceiveBufferSizeOption{Min: MinBufferSize, Default: DefaultReceiveBufferSize, Max: MaxBufferSize}, - } + return &protocol{} } diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go index 313a3f117..66e8911c8 100644 --- a/pkg/tcpip/transport/udp/udp_test.go +++ b/pkg/tcpip/transport/udp/udp_test.go @@ -83,16 +83,18 @@ type header4Tuple struct { type testFlow int const ( - unicastV4 testFlow = iota // V4 unicast on a V4 socket - unicastV4in6 // V4-mapped unicast on a V6-dual socket - unicastV6 // V6 unicast on a V6 socket - unicastV6Only // V6 unicast on a V6-only socket - multicastV4 // V4 multicast on a V4 socket - multicastV4in6 // V4-mapped multicast on a V6-dual socket - multicastV6 // V6 multicast on a V6 socket - multicastV6Only // V6 multicast on a V6-only socket - broadcast // V4 broadcast on a V4 socket - broadcastIn6 // V4-mapped broadcast on a V6-dual socket + unicastV4 testFlow = iota // V4 unicast on a V4 socket + unicastV4in6 // V4-mapped unicast on a V6-dual socket + unicastV6 // V6 unicast on a V6 socket + unicastV6Only // V6 unicast on a V6-only socket + multicastV4 // V4 multicast on a V4 socket + multicastV4in6 // V4-mapped multicast on a V6-dual socket + multicastV6 // V6 multicast on a V6 socket + multicastV6Only // V6 multicast on a V6-only socket + broadcast // V4 broadcast on a V4 socket + broadcastIn6 // V4-mapped broadcast on a V6-dual socket + reverseMulticast4 // V4 multicast src. Must fail. + reverseMulticast6 // V6 multicast src. Must fail. ) func (flow testFlow) String() string { @@ -117,6 +119,10 @@ func (flow testFlow) String() string { return "broadcast" case broadcastIn6: return "broadcastIn6" + case reverseMulticast4: + return "reverseMulticast4" + case reverseMulticast6: + return "reverseMulticast6" default: return "unknown" } @@ -168,6 +174,9 @@ func (flow testFlow) header4Tuple(d packetDirection) header4Tuple { h.dstAddr.Addr = multicastV6Addr } } + if flow.isReverseMulticast() { + h.srcAddr.Addr = flow.getMcastAddr() + } return h } @@ -199,9 +208,9 @@ func (flow testFlow) netProto() tcpip.NetworkProtocolNumber { // endpoint for this flow. func (flow testFlow) sockProto() tcpip.NetworkProtocolNumber { switch flow { - case unicastV4in6, unicastV6, unicastV6Only, multicastV4in6, multicastV6, multicastV6Only, broadcastIn6: + case unicastV4in6, unicastV6, unicastV6Only, multicastV4in6, multicastV6, multicastV6Only, broadcastIn6, reverseMulticast6: return ipv6.ProtocolNumber - case unicastV4, multicastV4, broadcast: + case unicastV4, multicastV4, broadcast, reverseMulticast4: return ipv4.ProtocolNumber default: panic(fmt.Sprintf("invalid testFlow given: %d", flow)) @@ -224,7 +233,7 @@ func (flow testFlow) isV6Only() bool { switch flow { case unicastV6Only, multicastV6Only: return true - case unicastV4, unicastV4in6, unicastV6, multicastV4, multicastV4in6, multicastV6, broadcast, broadcastIn6: + case unicastV4, unicastV4in6, unicastV6, multicastV4, multicastV4in6, multicastV6, broadcast, broadcastIn6, reverseMulticast4, reverseMulticast6: return false default: panic(fmt.Sprintf("invalid testFlow given: %d", flow)) @@ -235,7 +244,7 @@ func (flow testFlow) isMulticast() bool { switch flow { case multicastV4, multicastV4in6, multicastV6, multicastV6Only: return true - case unicastV4, unicastV4in6, unicastV6, unicastV6Only, broadcast, broadcastIn6: + case unicastV4, unicastV4in6, unicastV6, unicastV6Only, broadcast, broadcastIn6, reverseMulticast4, reverseMulticast6: return false default: panic(fmt.Sprintf("invalid testFlow given: %d", flow)) @@ -246,7 +255,7 @@ func (flow testFlow) isBroadcast() bool { switch flow { case broadcast, broadcastIn6: return true - case unicastV4, unicastV4in6, unicastV6, unicastV6Only, multicastV4, multicastV4in6, multicastV6, multicastV6Only: + case unicastV4, unicastV4in6, unicastV6, unicastV6Only, multicastV4, multicastV4in6, multicastV6, multicastV6Only, reverseMulticast4, reverseMulticast6: return false default: panic(fmt.Sprintf("invalid testFlow given: %d", flow)) @@ -257,13 +266,22 @@ func (flow testFlow) isMapped() bool { switch flow { case unicastV4in6, multicastV4in6, broadcastIn6: return true - case unicastV4, unicastV6, unicastV6Only, multicastV4, multicastV6, multicastV6Only, broadcast: + case unicastV4, unicastV6, unicastV6Only, multicastV4, multicastV6, multicastV6Only, broadcast, reverseMulticast4, reverseMulticast6: return false default: panic(fmt.Sprintf("invalid testFlow given: %d", flow)) } } +func (flow testFlow) isReverseMulticast() bool { + switch flow { + case reverseMulticast4, reverseMulticast6: + return true + default: + return false + } +} + type testContext struct { t *testing.T linkEP *channel.Endpoint @@ -292,15 +310,15 @@ func newDualTestContextWithOptions(t *testing.T, mtu uint32, options stack.Optio wep = sniffer.New(ep) } if err := s.CreateNIC(1, wep); err != nil { - t.Fatalf("CreateNIC failed: %v", err) + t.Fatalf("CreateNIC failed: %s", err) } if err := s.AddAddress(1, ipv4.ProtocolNumber, stackAddr); err != nil { - t.Fatalf("AddAddress failed: %v", err) + t.Fatalf("AddAddress failed: %s", err) } if err := s.AddAddress(1, ipv6.ProtocolNumber, stackV6Addr); err != nil { - t.Fatalf("AddAddress failed: %v", err) + t.Fatalf("AddAddress failed: %s", err) } s.SetRouteTable([]tcpip.Route{ @@ -391,17 +409,21 @@ func (c *testContext) injectPacket(flow testFlow, payload []byte) { h := flow.header4Tuple(incoming) if flow.isV4() { - c.injectV4Packet(payload, &h, true /* valid */) + buf := c.buildV4Packet(payload, &h) + c.linkEP.InjectInbound(ipv4.ProtocolNumber, &stack.PacketBuffer{ + Data: buf.ToVectorisedView(), + }) } else { - c.injectV6Packet(payload, &h, true /* valid */) + buf := c.buildV6Packet(payload, &h) + c.linkEP.InjectInbound(ipv6.ProtocolNumber, &stack.PacketBuffer{ + Data: buf.ToVectorisedView(), + }) } } -// injectV6Packet creates a V6 test packet with the given payload and header -// values, and injects it into the link endpoint. valid indicates if the -// caller intends to inject a packet with a valid or an invalid UDP header. -// We can invalidate the header by corrupting the UDP payload length. -func (c *testContext) injectV6Packet(payload []byte, h *header4Tuple, valid bool) { +// buildV6Packet creates a V6 test packet with the given payload and header +// values in a buffer. +func (c *testContext) buildV6Packet(payload []byte, h *header4Tuple) buffer.View { // Allocate a buffer for data and headers. buf := buffer.NewView(header.UDPMinimumSize + header.IPv6MinimumSize + len(payload)) payloadStart := len(buf) - len(payload) @@ -420,16 +442,10 @@ func (c *testContext) injectV6Packet(payload []byte, h *header4Tuple, valid bool // Initialize the UDP header. u := header.UDP(buf[header.IPv6MinimumSize:]) - l := uint16(header.UDPMinimumSize + len(payload)) - if !valid { - // Change the UDP payload length to corrupt the header - // as requested by the caller. - l++ - } u.Encode(&header.UDPFields{ SrcPort: h.srcAddr.Port, DstPort: h.dstAddr.Port, - Length: l, + Length: uint16(header.UDPMinimumSize + len(payload)), }) // Calculate the UDP pseudo-header checksum. @@ -439,17 +455,12 @@ func (c *testContext) injectV6Packet(payload []byte, h *header4Tuple, valid bool xsum = header.Checksum(payload, xsum) u.SetChecksum(^u.CalculateChecksum(xsum)) - // Inject packet. - c.linkEP.InjectInbound(ipv6.ProtocolNumber, &stack.PacketBuffer{ - Data: buf.ToVectorisedView(), - }) + return buf } -// injectV4Packet creates a V4 test packet with the given payload and header -// values, and injects it into the link endpoint. valid indicates if the -// caller intends to inject a packet with a valid or an invalid UDP header. -// We can invalidate the header by corrupting the UDP payload length. -func (c *testContext) injectV4Packet(payload []byte, h *header4Tuple, valid bool) { +// buildV4Packet creates a V4 test packet with the given payload and header +// values in a buffer. +func (c *testContext) buildV4Packet(payload []byte, h *header4Tuple) buffer.View { // Allocate a buffer for data and headers. buf := buffer.NewView(header.UDPMinimumSize + header.IPv4MinimumSize + len(payload)) payloadStart := len(buf) - len(payload) @@ -483,11 +494,7 @@ func (c *testContext) injectV4Packet(payload []byte, h *header4Tuple, valid bool xsum = header.Checksum(payload, xsum) u.SetChecksum(^u.CalculateChecksum(xsum)) - // Inject packet. - - c.linkEP.InjectInbound(ipv4.ProtocolNumber, &stack.PacketBuffer{ - Data: buf.ToVectorisedView(), - }) + return buf } func newPayload() []byte { @@ -509,7 +516,7 @@ func TestBindToDeviceOption(t *testing.T) { ep, err := s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{}) if err != nil { - t.Fatalf("NewEndpoint failed; %v", err) + t.Fatalf("NewEndpoint failed; %s", err) } defer ep.Close() @@ -643,7 +650,7 @@ func TestBindEphemeralPort(t *testing.T) { c.createEndpoint(ipv6.ProtocolNumber) if err := c.ep.Bind(tcpip.FullAddress{}); err != nil { - t.Fatalf("ep.Bind(...) failed: %v", err) + t.Fatalf("ep.Bind(...) failed: %s", err) } } @@ -654,19 +661,19 @@ func TestBindReservedPort(t *testing.T) { c.createEndpoint(ipv6.ProtocolNumber) if err := c.ep.Connect(tcpip.FullAddress{Addr: testV6Addr, Port: testPort}); err != nil { - c.t.Fatalf("Connect failed: %v", err) + c.t.Fatalf("Connect failed: %s", err) } addr, err := c.ep.GetLocalAddress() if err != nil { - t.Fatalf("GetLocalAddress failed: %v", err) + t.Fatalf("GetLocalAddress failed: %s", err) } // We can't bind the address reserved by the connected endpoint above. { ep, err := c.s.NewEndpoint(udp.ProtocolNumber, ipv6.ProtocolNumber, &c.wq) if err != nil { - t.Fatalf("NewEndpoint failed: %v", err) + t.Fatalf("NewEndpoint failed: %s", err) } defer ep.Close() if got, want := ep.Bind(addr), tcpip.ErrPortInUse; got != want { @@ -677,7 +684,7 @@ func TestBindReservedPort(t *testing.T) { func() { ep, err := c.s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &c.wq) if err != nil { - t.Fatalf("NewEndpoint failed: %v", err) + t.Fatalf("NewEndpoint failed: %s", err) } defer ep.Close() // We can't bind ipv4-any on the port reserved by the connected endpoint @@ -687,7 +694,7 @@ func TestBindReservedPort(t *testing.T) { } // We can bind an ipv4 address on this port, though. if err := ep.Bind(tcpip.FullAddress{Addr: stackAddr, Port: addr.Port}); err != nil { - t.Fatalf("ep.Bind(...) failed: %v", err) + t.Fatalf("ep.Bind(...) failed: %s", err) } }() @@ -697,11 +704,11 @@ func TestBindReservedPort(t *testing.T) { func() { ep, err := c.s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &c.wq) if err != nil { - t.Fatalf("NewEndpoint failed: %v", err) + t.Fatalf("NewEndpoint failed: %s", err) } defer ep.Close() if err := ep.Bind(tcpip.FullAddress{Port: addr.Port}); err != nil { - t.Fatalf("ep.Bind(...) failed: %v", err) + t.Fatalf("ep.Bind(...) failed: %s", err) } }() } @@ -714,7 +721,7 @@ func TestV4ReadOnV6(t *testing.T) { // Bind to wildcard. if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil { - c.t.Fatalf("Bind failed: %v", err) + c.t.Fatalf("Bind failed: %s", err) } // Test acceptance. @@ -729,7 +736,7 @@ func TestV4ReadOnBoundToV4MappedWildcard(t *testing.T) { // Bind to v4 mapped wildcard. if err := c.ep.Bind(tcpip.FullAddress{Addr: v4MappedWildcardAddr, Port: stackPort}); err != nil { - c.t.Fatalf("Bind failed: %v", err) + c.t.Fatalf("Bind failed: %s", err) } // Test acceptance. @@ -744,7 +751,7 @@ func TestV4ReadOnBoundToV4Mapped(t *testing.T) { // Bind to local address. if err := c.ep.Bind(tcpip.FullAddress{Addr: stackV4MappedAddr, Port: stackPort}); err != nil { - c.t.Fatalf("Bind failed: %v", err) + c.t.Fatalf("Bind failed: %s", err) } // Test acceptance. @@ -759,7 +766,7 @@ func TestV6ReadOnV6(t *testing.T) { // Bind to wildcard. if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil { - c.t.Fatalf("Bind failed: %v", err) + c.t.Fatalf("Bind failed: %s", err) } // Test acceptance. @@ -796,7 +803,10 @@ func TestV4ReadSelfSource(t *testing.T) { h := unicastV4.header4Tuple(incoming) h.srcAddr = h.dstAddr - c.injectV4Packet(payload, &h, true /* valid */) + buf := c.buildV4Packet(payload, &h) + c.linkEP.InjectInbound(ipv4.ProtocolNumber, &stack.PacketBuffer{ + Data: buf.ToVectorisedView(), + }) if got := c.s.Stats().IP.InvalidSourceAddressesReceived.Value(); got != tt.wantInvalidSource { t.Errorf("c.s.Stats().IP.InvalidSourceAddressesReceived got %d, want %d", got, tt.wantInvalidSource) @@ -817,7 +827,7 @@ func TestV4ReadOnV4(t *testing.T) { // Bind to wildcard. if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil { - c.t.Fatalf("Bind failed: %v", err) + c.t.Fatalf("Bind failed: %s", err) } // Test acceptance. @@ -880,6 +890,60 @@ func TestV4ReadOnBoundToBroadcast(t *testing.T) { } } +// TestReadFromMulticast checks that an endpoint will NOT receive a packet +// that was sent with multicast SOURCE address. +func TestReadFromMulticast(t *testing.T) { + for _, flow := range []testFlow{reverseMulticast4, reverseMulticast6} { + t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) { + c := newDualTestContext(t, defaultMTU) + defer c.cleanup() + + c.createEndpointForFlow(flow) + + if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil { + t.Fatalf("Bind failed: %s", err) + } + testFailingRead(c, flow, false /* expectReadError */) + }) + } +} + +// TestReadFromMulticaststats checks that a discarded packet +// that that was sent with multicast SOURCE address increments +// the correct counters and that a regular packet does not. +func TestReadFromMulticastStats(t *testing.T) { + t.Helper() + for _, flow := range []testFlow{reverseMulticast4, reverseMulticast6, unicastV4} { + t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) { + c := newDualTestContext(t, defaultMTU) + defer c.cleanup() + + c.createEndpointForFlow(flow) + + if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil { + t.Fatalf("Bind failed: %s", err) + } + + payload := newPayload() + c.injectPacket(flow, payload) + + var want uint64 = 0 + if flow.isReverseMulticast() { + want = 1 + } + if got := c.s.Stats().IP.InvalidSourceAddressesReceived.Value(); got != want { + t.Errorf("got stats.IP.InvalidSourceAddressesReceived.Value() = %d, want = %d", got, want) + } + if got := c.s.Stats().UDP.InvalidSourceAddress.Value(); got != want { + t.Errorf("got stats.UDP.InvalidSourceAddress.Value() = %d, want = %d", got, want) + } + if got := c.ep.Stats().(*tcpip.TransportEndpointStats).ReceiveErrors.MalformedPacketsReceived.Value(); got != want { + t.Errorf("got EP Stats.ReceiveErrors.MalformedPacketsReceived stats = %d, want = %d", got, want) + } + }) + } +} + // TestV4ReadBroadcastOnBoundToWildcard checks that an endpoint can bind to ANY // and receive broadcast and unicast data. func TestV4ReadBroadcastOnBoundToWildcard(t *testing.T) { @@ -955,7 +1019,7 @@ func testWriteInternal(c *testContext, flow testFlow, setDest bool, checkers ... payload := buffer.View(newPayload()) n, _, err := c.ep.Write(tcpip.SlicePayload(payload), writeOpts) if err != nil { - c.t.Fatalf("Write failed: %v", err) + c.t.Fatalf("Write failed: %s", err) } if n != int64(len(payload)) { c.t.Fatalf("Bad number of bytes written: got %v, want %v", n, len(payload)) @@ -1005,7 +1069,7 @@ func TestDualWriteBoundToWildcard(t *testing.T) { // Bind to wildcard. if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil { - c.t.Fatalf("Bind failed: %v", err) + c.t.Fatalf("Bind failed: %s", err) } p := testDualWrite(c) @@ -1022,7 +1086,7 @@ func TestDualWriteConnectedToV6(t *testing.T) { // Connect to v6 address. if err := c.ep.Connect(tcpip.FullAddress{Addr: testV6Addr, Port: testPort}); err != nil { - c.t.Fatalf("Bind failed: %v", err) + c.t.Fatalf("Bind failed: %s", err) } testWrite(c, unicastV6) @@ -1043,7 +1107,7 @@ func TestDualWriteConnectedToV4Mapped(t *testing.T) { // Connect to v4 mapped address. if err := c.ep.Connect(tcpip.FullAddress{Addr: testV4MappedAddr, Port: testPort}); err != nil { - c.t.Fatalf("Bind failed: %v", err) + c.t.Fatalf("Bind failed: %s", err) } testWrite(c, unicastV4in6) @@ -1070,7 +1134,7 @@ func TestV6WriteOnBoundToV4Mapped(t *testing.T) { // Bind to v4 mapped address. if err := c.ep.Bind(tcpip.FullAddress{Addr: stackV4MappedAddr, Port: stackPort}); err != nil { - c.t.Fatalf("Bind failed: %v", err) + c.t.Fatalf("Bind failed: %s", err) } // Write to v6 address. @@ -1085,7 +1149,7 @@ func TestV6WriteOnConnected(t *testing.T) { // Connect to v6 address. if err := c.ep.Connect(tcpip.FullAddress{Addr: testV6Addr, Port: testPort}); err != nil { - c.t.Fatalf("Connect failed: %v", err) + c.t.Fatalf("Connect failed: %s", err) } testWriteWithoutDestination(c, unicastV6) @@ -1099,7 +1163,7 @@ func TestV4WriteOnConnected(t *testing.T) { // Connect to v4 mapped address. if err := c.ep.Connect(tcpip.FullAddress{Addr: testV4MappedAddr, Port: testPort}); err != nil { - c.t.Fatalf("Connect failed: %v", err) + c.t.Fatalf("Connect failed: %s", err) } testWriteWithoutDestination(c, unicastV4) @@ -1234,7 +1298,7 @@ func TestReadIncrementsPacketsReceived(t *testing.T) { // Bind to wildcard. if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil { - c.t.Fatalf("Bind failed: %v", err) + c.t.Fatalf("Bind failed: %s", err) } testRead(c, unicastV4) @@ -1259,6 +1323,30 @@ func TestWriteIncrementsPacketsSent(t *testing.T) { } } +func TestNoChecksum(t *testing.T) { + for _, flow := range []testFlow{unicastV4, unicastV6} { + t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) { + c := newDualTestContext(t, defaultMTU) + defer c.cleanup() + + c.createEndpointForFlow(flow) + + // Disable the checksum generation. + if err := c.ep.SetSockOptBool(tcpip.NoChecksumOption, true); err != nil { + t.Fatalf("SetSockOptBool failed: %s", err) + } + // This option is effective on IPv4 only. + testWrite(c, flow, checker.UDP(checker.NoChecksum(flow.isV4()))) + + // Enable the checksum generation. + if err := c.ep.SetSockOptBool(tcpip.NoChecksumOption, false); err != nil { + t.Fatalf("SetSockOptBool failed: %s", err) + } + testWrite(c, flow, checker.UDP(checker.NoChecksum(false))) + }) + } +} + func TestTTL(t *testing.T) { for _, flow := range []testFlow{unicastV4, unicastV4in6, unicastV6, unicastV6Only, multicastV4, multicastV4in6, multicastV6, broadcast, broadcastIn6} { t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) { @@ -1506,12 +1594,12 @@ func TestMulticastInterfaceOption(t *testing.T) { Port: stackPort, } if err := c.ep.Connect(addr); err != nil { - c.t.Fatalf("Connect failed: %v", err) + c.t.Fatalf("Connect failed: %s", err) } } if err := c.ep.SetSockOpt(ifoptSet); err != nil { - c.t.Fatalf("SetSockOpt failed: %v", err) + c.t.Fatalf("SetSockOpt failed: %s", err) } // Verify multicast interface addr and NIC were set correctly. @@ -1519,7 +1607,7 @@ func TestMulticastInterfaceOption(t *testing.T) { ifoptWant := tcpip.MulticastInterfaceOption{NIC: 1, InterfaceAddr: ifoptSet.InterfaceAddr} var ifoptGot tcpip.MulticastInterfaceOption if err := c.ep.GetSockOpt(&ifoptGot); err != nil { - c.t.Fatalf("GetSockOpt failed: %v", err) + c.t.Fatalf("GetSockOpt failed: %s", err) } if ifoptGot != ifoptWant { c.t.Errorf("got GetSockOpt() = %#v, want = %#v", ifoptGot, ifoptWant) @@ -1691,7 +1779,7 @@ func TestV6UnknownDestination(t *testing.T) { } // TestIncrementMalformedPacketsReceived verifies if the malformed received -// global and endpoint stats get incremented. +// global and endpoint stats are incremented. func TestIncrementMalformedPacketsReceived(t *testing.T) { c := newDualTestContext(t, defaultMTU) defer c.cleanup() @@ -1699,20 +1787,27 @@ func TestIncrementMalformedPacketsReceived(t *testing.T) { c.createEndpoint(ipv6.ProtocolNumber) // Bind to wildcard. if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil { - c.t.Fatalf("Bind failed: %v", err) + c.t.Fatalf("Bind failed: %s", err) } payload := newPayload() - c.t.Helper() h := unicastV6.header4Tuple(incoming) - c.injectV6Packet(payload, &h, false /* !valid */) + buf := c.buildV6Packet(payload, &h) - var want uint64 = 1 + // Invalidate the UDP header length field. + u := header.UDP(buf[header.IPv6MinimumSize:]) + u.SetLength(u.Length() + 1) + + c.linkEP.InjectInbound(ipv6.ProtocolNumber, &stack.PacketBuffer{ + Data: buf.ToVectorisedView(), + }) + + const want = 1 if got := c.s.Stats().UDP.MalformedPacketsReceived.Value(); got != want { - t.Errorf("got stats.UDP.MalformedPacketsReceived.Value() = %v, want = %v", got, want) + t.Errorf("got stats.UDP.MalformedPacketsReceived.Value() = %d, want = %d", got, want) } if got := c.ep.Stats().(*tcpip.TransportEndpointStats).ReceiveErrors.MalformedPacketsReceived.Value(); got != want { - t.Errorf("got EP Stats.ReceiveErrors.MalformedPacketsReceived stats = %v, want = %v", got, want) + t.Errorf("got EP Stats.ReceiveErrors.MalformedPacketsReceived stats = %d, want = %d", got, want) } } @@ -1728,7 +1823,6 @@ func TestShortHeader(t *testing.T) { c.t.Fatalf("Bind failed: %s", err) } - c.t.Helper() h := unicastV6.header4Tuple(incoming) // Allocate a buffer for an IPv6 and too-short UDP header. @@ -1768,6 +1862,199 @@ func TestShortHeader(t *testing.T) { } } +// TestIncrementChecksumErrorsV4 verifies if a checksum error is detected, +// global and endpoint stats are incremented. +func TestIncrementChecksumErrorsV4(t *testing.T) { + c := newDualTestContext(t, defaultMTU) + defer c.cleanup() + + c.createEndpoint(ipv4.ProtocolNumber) + // Bind to wildcard. + if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil { + c.t.Fatalf("Bind failed: %s", err) + } + + payload := newPayload() + h := unicastV4.header4Tuple(incoming) + buf := c.buildV4Packet(payload, &h) + + // Invalidate the UDP header checksum field, taking care to avoid + // overflow to zero, which would disable checksum validation. + for u := header.UDP(buf[header.IPv4MinimumSize:]); ; { + u.SetChecksum(u.Checksum() + 1) + if u.Checksum() != 0 { + break + } + } + + c.linkEP.InjectInbound(ipv4.ProtocolNumber, &stack.PacketBuffer{ + Data: buf.ToVectorisedView(), + }) + + const want = 1 + if got := c.s.Stats().UDP.ChecksumErrors.Value(); got != want { + t.Errorf("got stats.UDP.ChecksumErrors.Value() = %d, want = %d", got, want) + } + if got := c.ep.Stats().(*tcpip.TransportEndpointStats).ReceiveErrors.ChecksumErrors.Value(); got != want { + t.Errorf("got EP Stats.ReceiveErrors.ChecksumErrors stats = %d, want = %d", got, want) + } +} + +// TestIncrementChecksumErrorsV6 verifies if a checksum error is detected, +// global and endpoint stats are incremented. +func TestIncrementChecksumErrorsV6(t *testing.T) { + c := newDualTestContext(t, defaultMTU) + defer c.cleanup() + + c.createEndpoint(ipv6.ProtocolNumber) + // Bind to wildcard. + if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil { + c.t.Fatalf("Bind failed: %s", err) + } + + payload := newPayload() + h := unicastV6.header4Tuple(incoming) + buf := c.buildV6Packet(payload, &h) + + // Invalidate the UDP header checksum field. + u := header.UDP(buf[header.IPv6MinimumSize:]) + u.SetChecksum(u.Checksum() + 1) + + c.linkEP.InjectInbound(ipv6.ProtocolNumber, &stack.PacketBuffer{ + Data: buf.ToVectorisedView(), + }) + + const want = 1 + if got := c.s.Stats().UDP.ChecksumErrors.Value(); got != want { + t.Errorf("got stats.UDP.ChecksumErrors.Value() = %d, want = %d", got, want) + } + if got := c.ep.Stats().(*tcpip.TransportEndpointStats).ReceiveErrors.ChecksumErrors.Value(); got != want { + t.Errorf("got EP Stats.ReceiveErrors.ChecksumErrors stats = %d, want = %d", got, want) + } +} + +// TestPayloadModifiedV4 verifies if a checksum error is detected, +// global and endpoint stats are incremented. +func TestPayloadModifiedV4(t *testing.T) { + c := newDualTestContext(t, defaultMTU) + defer c.cleanup() + + c.createEndpoint(ipv4.ProtocolNumber) + // Bind to wildcard. + if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil { + c.t.Fatalf("Bind failed: %s", err) + } + + payload := newPayload() + h := unicastV4.header4Tuple(incoming) + buf := c.buildV4Packet(payload, &h) + // Modify the payload so that the checksum value in the UDP header will be incorrect. + buf[len(buf)-1]++ + c.linkEP.InjectInbound(ipv4.ProtocolNumber, &stack.PacketBuffer{ + Data: buf.ToVectorisedView(), + }) + + const want = 1 + if got := c.s.Stats().UDP.ChecksumErrors.Value(); got != want { + t.Errorf("got stats.UDP.ChecksumErrors.Value() = %d, want = %d", got, want) + } + if got := c.ep.Stats().(*tcpip.TransportEndpointStats).ReceiveErrors.ChecksumErrors.Value(); got != want { + t.Errorf("got EP Stats.ReceiveErrors.ChecksumErrors stats = %d, want = %d", got, want) + } +} + +// TestPayloadModifiedV6 verifies if a checksum error is detected, +// global and endpoint stats are incremented. +func TestPayloadModifiedV6(t *testing.T) { + c := newDualTestContext(t, defaultMTU) + defer c.cleanup() + + c.createEndpoint(ipv6.ProtocolNumber) + // Bind to wildcard. + if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil { + c.t.Fatalf("Bind failed: %s", err) + } + + payload := newPayload() + h := unicastV6.header4Tuple(incoming) + buf := c.buildV6Packet(payload, &h) + // Modify the payload so that the checksum value in the UDP header will be incorrect. + buf[len(buf)-1]++ + c.linkEP.InjectInbound(ipv6.ProtocolNumber, &stack.PacketBuffer{ + Data: buf.ToVectorisedView(), + }) + + const want = 1 + if got := c.s.Stats().UDP.ChecksumErrors.Value(); got != want { + t.Errorf("got stats.UDP.ChecksumErrors.Value() = %d, want = %d", got, want) + } + if got := c.ep.Stats().(*tcpip.TransportEndpointStats).ReceiveErrors.ChecksumErrors.Value(); got != want { + t.Errorf("got EP Stats.ReceiveErrors.ChecksumErrors stats = %d, want = %d", got, want) + } +} + +// TestChecksumZeroV4 verifies if the checksum value is zero, global and +// endpoint states are *not* incremented (UDP checksum is optional on IPv4). +func TestChecksumZeroV4(t *testing.T) { + c := newDualTestContext(t, defaultMTU) + defer c.cleanup() + + c.createEndpoint(ipv4.ProtocolNumber) + // Bind to wildcard. + if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil { + c.t.Fatalf("Bind failed: %s", err) + } + + payload := newPayload() + h := unicastV4.header4Tuple(incoming) + buf := c.buildV4Packet(payload, &h) + // Set the checksum field in the UDP header to zero. + u := header.UDP(buf[header.IPv4MinimumSize:]) + u.SetChecksum(0) + c.linkEP.InjectInbound(ipv4.ProtocolNumber, &stack.PacketBuffer{ + Data: buf.ToVectorisedView(), + }) + + const want = 0 + if got := c.s.Stats().UDP.ChecksumErrors.Value(); got != want { + t.Errorf("got stats.UDP.ChecksumErrors.Value() = %d, want = %d", got, want) + } + if got := c.ep.Stats().(*tcpip.TransportEndpointStats).ReceiveErrors.ChecksumErrors.Value(); got != want { + t.Errorf("got EP Stats.ReceiveErrors.ChecksumErrors stats = %d, want = %d", got, want) + } +} + +// TestChecksumZeroV6 verifies if the checksum value is zero, global and +// endpoint states are incremented (UDP checksum is *not* optional on IPv6). +func TestChecksumZeroV6(t *testing.T) { + c := newDualTestContext(t, defaultMTU) + defer c.cleanup() + + c.createEndpoint(ipv6.ProtocolNumber) + // Bind to wildcard. + if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil { + c.t.Fatalf("Bind failed: %s", err) + } + + payload := newPayload() + h := unicastV6.header4Tuple(incoming) + buf := c.buildV6Packet(payload, &h) + // Set the checksum field in the UDP header to zero. + u := header.UDP(buf[header.IPv6MinimumSize:]) + u.SetChecksum(0) + c.linkEP.InjectInbound(ipv6.ProtocolNumber, &stack.PacketBuffer{ + Data: buf.ToVectorisedView(), + }) + + const want = 1 + if got := c.s.Stats().UDP.ChecksumErrors.Value(); got != want { + t.Errorf("got stats.UDP.ChecksumErrors.Value() = %d, want = %d", got, want) + } + if got := c.ep.Stats().(*tcpip.TransportEndpointStats).ReceiveErrors.ChecksumErrors.Value(); got != want { + t.Errorf("got EP Stats.ReceiveErrors.ChecksumErrors stats = %d, want = %d", got, want) + } +} + // TestShutdownRead verifies endpoint read shutdown and error // stats increment on packet receive. func TestShutdownRead(t *testing.T) { @@ -1778,15 +2065,15 @@ func TestShutdownRead(t *testing.T) { // Bind to wildcard. if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil { - c.t.Fatalf("Bind failed: %v", err) + c.t.Fatalf("Bind failed: %s", err) } if err := c.ep.Connect(tcpip.FullAddress{Addr: testV6Addr, Port: testPort}); err != nil { - c.t.Fatalf("Connect failed: %v", err) + c.t.Fatalf("Connect failed: %s", err) } if err := c.ep.Shutdown(tcpip.ShutdownRead); err != nil { - t.Fatalf("Shutdown failed: %v", err) + t.Fatalf("Shutdown failed: %s", err) } testFailingRead(c, unicastV6, true /* expectReadError */) @@ -1809,11 +2096,11 @@ func TestShutdownWrite(t *testing.T) { c.createEndpoint(ipv6.ProtocolNumber) if err := c.ep.Connect(tcpip.FullAddress{Addr: testV6Addr, Port: testPort}); err != nil { - c.t.Fatalf("Connect failed: %v", err) + c.t.Fatalf("Connect failed: %s", err) } if err := c.ep.Shutdown(tcpip.ShutdownWrite); err != nil { - t.Fatalf("Shutdown failed: %v", err) + t.Fatalf("Shutdown failed: %s", err) } testFailingWrite(c, unicastV6, tcpip.ErrClosedForSend) @@ -1855,3 +2142,192 @@ func (c *testContext) checkEndpointReadStats(incr uint64, want tcpip.TransportEn c.t.Errorf("Endpoint stats not matching for error %s got %+v want %+v", err, got, want) } } + +func TestOutgoingSubnetBroadcast(t *testing.T) { + const nicID1 = 1 + + ipv4Addr := tcpip.AddressWithPrefix{ + Address: "\xc0\xa8\x01\x3a", + PrefixLen: 24, + } + ipv4Subnet := ipv4Addr.Subnet() + ipv4SubnetBcast := ipv4Subnet.Broadcast() + ipv4Gateway := tcpip.Address("\xc0\xa8\x01\x01") + ipv4AddrPrefix31 := tcpip.AddressWithPrefix{ + Address: "\xc0\xa8\x01\x3a", + PrefixLen: 31, + } + ipv4Subnet31 := ipv4AddrPrefix31.Subnet() + ipv4Subnet31Bcast := ipv4Subnet31.Broadcast() + ipv4AddrPrefix32 := tcpip.AddressWithPrefix{ + Address: "\xc0\xa8\x01\x3a", + PrefixLen: 32, + } + ipv4Subnet32 := ipv4AddrPrefix32.Subnet() + ipv4Subnet32Bcast := ipv4Subnet32.Broadcast() + ipv6Addr := tcpip.AddressWithPrefix{ + Address: "\x20\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01", + PrefixLen: 64, + } + ipv6Subnet := ipv6Addr.Subnet() + ipv6SubnetBcast := ipv6Subnet.Broadcast() + remNetAddr := tcpip.AddressWithPrefix{ + Address: "\x64\x0a\x7b\x18", + PrefixLen: 24, + } + remNetSubnet := remNetAddr.Subnet() + remNetSubnetBcast := remNetSubnet.Broadcast() + + tests := []struct { + name string + nicAddr tcpip.ProtocolAddress + routes []tcpip.Route + remoteAddr tcpip.Address + requiresBroadcastOpt bool + }{ + { + name: "IPv4 Broadcast to local subnet", + nicAddr: tcpip.ProtocolAddress{ + Protocol: header.IPv4ProtocolNumber, + AddressWithPrefix: ipv4Addr, + }, + routes: []tcpip.Route{ + { + Destination: ipv4Subnet, + NIC: nicID1, + }, + }, + remoteAddr: ipv4SubnetBcast, + requiresBroadcastOpt: true, + }, + { + name: "IPv4 Broadcast to local /31 subnet", + nicAddr: tcpip.ProtocolAddress{ + Protocol: header.IPv4ProtocolNumber, + AddressWithPrefix: ipv4AddrPrefix31, + }, + routes: []tcpip.Route{ + { + Destination: ipv4Subnet31, + NIC: nicID1, + }, + }, + remoteAddr: ipv4Subnet31Bcast, + requiresBroadcastOpt: false, + }, + { + name: "IPv4 Broadcast to local /32 subnet", + nicAddr: tcpip.ProtocolAddress{ + Protocol: header.IPv4ProtocolNumber, + AddressWithPrefix: ipv4AddrPrefix32, + }, + routes: []tcpip.Route{ + { + Destination: ipv4Subnet32, + NIC: nicID1, + }, + }, + remoteAddr: ipv4Subnet32Bcast, + requiresBroadcastOpt: false, + }, + // IPv6 has no notion of a broadcast. + { + name: "IPv6 'Broadcast' to local subnet", + nicAddr: tcpip.ProtocolAddress{ + Protocol: header.IPv6ProtocolNumber, + AddressWithPrefix: ipv6Addr, + }, + routes: []tcpip.Route{ + { + Destination: ipv6Subnet, + NIC: nicID1, + }, + }, + remoteAddr: ipv6SubnetBcast, + requiresBroadcastOpt: false, + }, + { + name: "IPv4 Broadcast to remote subnet", + nicAddr: tcpip.ProtocolAddress{ + Protocol: header.IPv4ProtocolNumber, + AddressWithPrefix: ipv4Addr, + }, + routes: []tcpip.Route{ + { + Destination: remNetSubnet, + Gateway: ipv4Gateway, + NIC: nicID1, + }, + }, + remoteAddr: remNetSubnetBcast, + requiresBroadcastOpt: true, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol()}, + + TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()}, + }) + e := channel.New(0, defaultMTU, "") + if err := s.CreateNIC(nicID1, e); err != nil { + t.Fatalf("CreateNIC(%d, _): %s", nicID1, err) + } + if err := s.AddProtocolAddress(nicID1, test.nicAddr); err != nil { + t.Fatalf("AddProtocolAddress(%d, %+v): %s", nicID1, test.nicAddr, err) + } + + s.SetRouteTable(test.routes) + + var netProto tcpip.NetworkProtocolNumber + switch l := len(test.remoteAddr); l { + case header.IPv4AddressSize: + netProto = header.IPv4ProtocolNumber + case header.IPv6AddressSize: + netProto = header.IPv6ProtocolNumber + default: + t.Fatalf("got unexpected address length = %d bytes", l) + } + + wq := waiter.Queue{} + ep, err := s.NewEndpoint(udp.ProtocolNumber, netProto, &wq) + if err != nil { + t.Fatalf("NewEndpoint(%d, %d, _): %s", udp.ProtocolNumber, netProto, err) + } + defer ep.Close() + + data := tcpip.SlicePayload([]byte{1, 2, 3, 4}) + to := tcpip.FullAddress{ + Addr: test.remoteAddr, + Port: 80, + } + opts := tcpip.WriteOptions{To: &to} + expectedErrWithoutBcastOpt := tcpip.ErrBroadcastDisabled + if !test.requiresBroadcastOpt { + expectedErrWithoutBcastOpt = nil + } + + if n, _, err := ep.Write(data, opts); err != expectedErrWithoutBcastOpt { + t.Fatalf("got ep.Write(_, _) = (%d, _, %v), want = (_, _, %v)", n, err, expectedErrWithoutBcastOpt) + } + + if err := ep.SetSockOptBool(tcpip.BroadcastOption, true); err != nil { + t.Fatalf("got SetSockOptBool(BroadcastOption, true): %s", err) + } + + if n, _, err := ep.Write(data, opts); err != nil { + t.Fatalf("got ep.Write(_, _) = (%d, _, %s), want = (_, _, nil)", n, err) + } + + if err := ep.SetSockOptBool(tcpip.BroadcastOption, false); err != nil { + t.Fatalf("got SetSockOptBool(BroadcastOption, false): %s", err) + } + + if n, _, err := ep.Write(data, opts); err != expectedErrWithoutBcastOpt { + t.Fatalf("got ep.Write(_, _) = (%d, _, %v), want = (_, _, %v)", n, err, expectedErrWithoutBcastOpt) + } + }) + } +} |