diff options
Diffstat (limited to 'pkg/tcpip')
-rw-r--r-- | pkg/tcpip/network/arp/arp.go | 2 | ||||
-rw-r--r-- | pkg/tcpip/network/ipv4/icmp.go | 2 | ||||
-rw-r--r-- | pkg/tcpip/network/ipv4/ipv4.go | 9 | ||||
-rw-r--r-- | pkg/tcpip/network/ipv6/icmp.go | 4 | ||||
-rw-r--r-- | pkg/tcpip/network/ipv6/ipv6.go | 7 | ||||
-rw-r--r-- | pkg/tcpip/stack/registration.go | 15 | ||||
-rw-r--r-- | pkg/tcpip/stack/route.go | 8 | ||||
-rw-r--r-- | pkg/tcpip/stack/stack.go | 5 | ||||
-rw-r--r-- | pkg/tcpip/tcpip.go | 8 | ||||
-rw-r--r-- | pkg/tcpip/transport/icmp/endpoint.go | 10 | ||||
-rw-r--r-- | pkg/tcpip/transport/raw/endpoint.go | 2 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/accept.go | 2 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/connect.go | 27 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/endpoint.go | 35 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/protocol.go | 2 | ||||
-rwxr-xr-x | pkg/tcpip/transport/tcp/tcp_state_autogen.go | 2 | ||||
-rw-r--r-- | pkg/tcpip/transport/udp/endpoint.go | 37 | ||||
-rw-r--r-- | pkg/tcpip/transport/udp/protocol.go | 4 | ||||
-rwxr-xr-x | pkg/tcpip/transport/udp/udp_state_autogen.go | 2 |
19 files changed, 143 insertions, 40 deletions
diff --git a/pkg/tcpip/network/arp/arp.go b/pkg/tcpip/network/arp/arp.go index 26cf1c528..922181ac0 100644 --- a/pkg/tcpip/network/arp/arp.go +++ b/pkg/tcpip/network/arp/arp.go @@ -79,7 +79,7 @@ func (e *endpoint) MaxHeaderLength() uint16 { func (e *endpoint) Close() {} -func (e *endpoint) WritePacket(*stack.Route, *stack.GSO, buffer.Prependable, buffer.VectorisedView, tcpip.TransportProtocolNumber, uint8, stack.PacketLooping) *tcpip.Error { +func (e *endpoint) WritePacket(*stack.Route, *stack.GSO, buffer.Prependable, buffer.VectorisedView, stack.NetworkHeaderParams, stack.PacketLooping) *tcpip.Error { return tcpip.ErrNotSupported } diff --git a/pkg/tcpip/network/ipv4/icmp.go b/pkg/tcpip/network/ipv4/icmp.go index c1cf6c222..50b363dc4 100644 --- a/pkg/tcpip/network/ipv4/icmp.go +++ b/pkg/tcpip/network/ipv4/icmp.go @@ -95,7 +95,7 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, vv buffer.V pkt.SetChecksum(0) pkt.SetChecksum(^header.Checksum(pkt, header.ChecksumVV(vv, 0))) sent := stats.ICMP.V4PacketsSent - if err := r.WritePacket(nil /* gso */, hdr, vv, header.ICMPv4ProtocolNumber, 0, true /* useDefaultTTL */); err != nil { + if err := r.WritePacket(nil /* gso */, hdr, vv, stack.NetworkHeaderParams{Protocol: header.ICMPv4ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS}); err != nil { sent.Dropped.Increment() return } diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go index fb6358fbb..df1a08113 100644 --- a/pkg/tcpip/network/ipv4/ipv4.go +++ b/pkg/tcpip/network/ipv4/ipv4.go @@ -199,21 +199,22 @@ func (e *endpoint) writePacketFragments(r *stack.Route, gso *stack.GSO, hdr buff } // WritePacket writes a packet to the given destination address and protocol. -func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.TransportProtocolNumber, ttl uint8, loop stack.PacketLooping) *tcpip.Error { +func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, hdr buffer.Prependable, payload buffer.VectorisedView, params stack.NetworkHeaderParams, loop stack.PacketLooping) *tcpip.Error { ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize)) length := uint16(hdr.UsedLength() + payload.Size()) id := uint32(0) if length > header.IPv4MaximumHeaderSize+8 { // Packets of 68 bytes or less are required by RFC 791 to not be // fragmented, so we only assign ids to larger packets. - id = atomic.AddUint32(&e.protocol.ids[hashRoute(r, protocol, e.protocol.hashIV)%buckets], 1) + id = atomic.AddUint32(&e.protocol.ids[hashRoute(r, params.Protocol, e.protocol.hashIV)%buckets], 1) } ip.Encode(&header.IPv4Fields{ IHL: header.IPv4MinimumSize, TotalLength: length, ID: uint16(id), - TTL: ttl, - Protocol: uint8(protocol), + TTL: params.TTL, + TOS: params.TOS, + Protocol: uint8(params.Protocol), SrcAddr: r.LocalAddress, DstAddr: r.RemoteAddress, }) diff --git a/pkg/tcpip/network/ipv6/icmp.go b/pkg/tcpip/network/ipv6/icmp.go index 71c398027..b5df85455 100644 --- a/pkg/tcpip/network/ipv6/icmp.go +++ b/pkg/tcpip/network/ipv6/icmp.go @@ -154,7 +154,7 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, vv buffer.V r.LocalAddress = targetAddr pkt.SetChecksum(header.ICMPv6Checksum(pkt, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{})) - if err := r.WritePacket(nil /* gso */, hdr, buffer.VectorisedView{}, header.ICMPv6ProtocolNumber, 0, true /* useDefaultTTL */); err != nil { + if err := r.WritePacket(nil /* gso */, hdr, buffer.VectorisedView{}, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS}); err != nil { sent.Dropped.Increment() return } @@ -185,7 +185,7 @@ func (e *endpoint) handleICMP(r *stack.Route, netHeader buffer.View, vv buffer.V copy(pkt, h) pkt.SetType(header.ICMPv6EchoReply) pkt.SetChecksum(header.ICMPv6Checksum(pkt, r.LocalAddress, r.RemoteAddress, vv)) - if err := r.WritePacket(nil /* gso */, hdr, vv, header.ICMPv6ProtocolNumber, 0, true /* useDefaultTTL */); err != nil { + if err := r.WritePacket(nil /* gso */, hdr, vv, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS}); err != nil { sent.Dropped.Increment() return } diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go index 85c070c43..cd1e34085 100644 --- a/pkg/tcpip/network/ipv6/ipv6.go +++ b/pkg/tcpip/network/ipv6/ipv6.go @@ -98,13 +98,14 @@ func (e *endpoint) GSOMaxSize() uint32 { } // WritePacket writes a packet to the given destination address and protocol. -func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.TransportProtocolNumber, ttl uint8, loop stack.PacketLooping) *tcpip.Error { +func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, hdr buffer.Prependable, payload buffer.VectorisedView, params stack.NetworkHeaderParams, loop stack.PacketLooping) *tcpip.Error { length := uint16(hdr.UsedLength() + payload.Size()) ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) ip.Encode(&header.IPv6Fields{ PayloadLength: length, - NextHeader: uint8(protocol), - HopLimit: ttl, + NextHeader: uint8(params.Protocol), + HopLimit: params.TTL, + TrafficClass: params.TOS, SrcAddr: r.LocalAddress, DstAddr: r.RemoteAddress, }) diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go index 80101d4bb..9d6157f22 100644 --- a/pkg/tcpip/stack/registration.go +++ b/pkg/tcpip/stack/registration.go @@ -146,6 +146,19 @@ const ( PacketLoop ) +// NetworkHeaderParams are the header parameters given as input by the +// transport endpoint to the network. +type NetworkHeaderParams struct { + // Protocol refers to the transport protocol number. + Protocol tcpip.TransportProtocolNumber + + // TTL refers to Time To Live field of the IP-header. + TTL uint8 + + // TOS refers to TypeOfService or TrafficClass field of the IP-header. + TOS uint8 +} + // NetworkEndpoint is the interface that needs to be implemented by endpoints // of network layer protocols (e.g., ipv4, ipv6). type NetworkEndpoint interface { @@ -170,7 +183,7 @@ type NetworkEndpoint interface { // WritePacket writes a packet to the given destination address and // protocol. - WritePacket(r *Route, gso *GSO, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.TransportProtocolNumber, ttl uint8, loop PacketLooping) *tcpip.Error + WritePacket(r *Route, gso *GSO, hdr buffer.Prependable, payload buffer.VectorisedView, params NetworkHeaderParams, loop PacketLooping) *tcpip.Error // WriteHeaderIncludedPacket writes a packet that includes a network // header to the given destination address. diff --git a/pkg/tcpip/stack/route.go b/pkg/tcpip/stack/route.go index 22db619a1..e72373964 100644 --- a/pkg/tcpip/stack/route.go +++ b/pkg/tcpip/stack/route.go @@ -154,16 +154,12 @@ func (r *Route) IsResolutionRequired() bool { } // WritePacket writes the packet through the given route. -func (r *Route) WritePacket(gso *GSO, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.TransportProtocolNumber, ttl uint8, useDefaultTTL bool) *tcpip.Error { +func (r *Route) WritePacket(gso *GSO, hdr buffer.Prependable, payload buffer.VectorisedView, params NetworkHeaderParams) *tcpip.Error { if !r.ref.isValidForOutgoing() { return tcpip.ErrInvalidEndpointState } - if useDefaultTTL { - ttl = r.DefaultTTL() - } - - err := r.ref.ep.WritePacket(r, gso, hdr, payload, protocol, ttl, r.loop) + err := r.ref.ep.WritePacket(r, gso, hdr, payload, params, r.loop) if err != nil { r.Stats().IP.OutgoingPacketErrors.Increment() } else { diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index 7d73389cc..f67975525 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -43,6 +43,9 @@ const ( resolutionTimeout = 1 * time.Second // resolutionAttempts is set to the same ARP retries used in Linux. resolutionAttempts = 3 + + // DefaultTOS is the default type of service value for network endpoints. + DefaultTOS = 0 ) type transportProtocolState struct { @@ -394,7 +397,7 @@ type Stack struct { // portSeed is a one-time random value initialized at stack startup // and is used to seed the TCP port picking on active connections // - // TODO(gvisor.dev/issues/940): S/R this field. + // TODO(gvisor.dev/issue/940): S/R this field. portSeed uint32 } diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go index 444ac1a5b..9d3752032 100644 --- a/pkg/tcpip/tcpip.go +++ b/pkg/tcpip/tcpip.go @@ -624,6 +624,14 @@ type BroadcastOption int // a default TTL. type DefaultTTLOption uint8 +// IPv4TOSOption is used by SetSockOpt/GetSockOpt to specify TOS +// for all subsequent outgoing IPv4 packets from the endpoint. +type IPv4TOSOption uint8 + +// IPv6TrafficClassOption is used by SetSockOpt/GetSockOpt to specify TOS +// for all subsequent outgoing IPv6 packets from the endpoint. +type IPv6TrafficClassOption uint8 + // Route is a row in the routing table. It specifies through which NIC (and // gateway) sets of packets should be routed. A row is considered viable if the // masked target address matches the destination address in the row. diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go index d0cfdcda1..3187b336b 100644 --- a/pkg/tcpip/transport/icmp/endpoint.go +++ b/pkg/tcpip/transport/icmp/endpoint.go @@ -422,7 +422,10 @@ func send4(r *stack.Route, ident uint16, data buffer.View, ttl uint8) *tcpip.Err icmpv4.SetChecksum(0) icmpv4.SetChecksum(^header.Checksum(icmpv4, header.Checksum(data, 0))) - return r.WritePacket(nil /* gso */, hdr, data.ToVectorisedView(), header.ICMPv4ProtocolNumber, ttl, ttl == 0 /* useDefaultTTL */) + if ttl == 0 { + ttl = r.DefaultTTL() + } + return r.WritePacket(nil /* gso */, hdr, data.ToVectorisedView(), stack.NetworkHeaderParams{Protocol: header.ICMPv4ProtocolNumber, TTL: ttl, TOS: stack.DefaultTOS}) } func send6(r *stack.Route, ident uint16, data buffer.View, ttl uint8) *tcpip.Error { @@ -445,7 +448,10 @@ func send6(r *stack.Route, ident uint16, data buffer.View, ttl uint8) *tcpip.Err icmpv6.SetChecksum(0) icmpv6.SetChecksum(^header.Checksum(icmpv6, header.Checksum(data, 0))) - return r.WritePacket(nil /* gso */, hdr, data.ToVectorisedView(), header.ICMPv6ProtocolNumber, ttl, ttl == 0 /* useDefaultTTL */) + if ttl == 0 { + ttl = r.DefaultTTL() + } + return r.WritePacket(nil /* gso */, hdr, data.ToVectorisedView(), stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: ttl, TOS: stack.DefaultTOS}) } func (e *endpoint) checkV4Mapped(addr *tcpip.FullAddress, allowMismatch bool) (tcpip.NetworkProtocolNumber, *tcpip.Error) { diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go index 4f5c286cf..b4c660859 100644 --- a/pkg/tcpip/transport/raw/endpoint.go +++ b/pkg/tcpip/transport/raw/endpoint.go @@ -350,7 +350,7 @@ func (e *endpoint) finishWrite(payloadBytes []byte, route *stack.Route) (int64, break } hdr := buffer.NewPrependable(len(payloadBytes) + int(route.MaxHeaderLength())) - if err := route.WritePacket(nil /* gso */, hdr, buffer.View(payloadBytes).ToVectorisedView(), e.TransProto, 0, true /* useDefaultTTL */); err != nil { + if err := route.WritePacket(nil /* gso */, hdr, buffer.View(payloadBytes).ToVectorisedView(), stack.NetworkHeaderParams{Protocol: e.TransProto, TTL: route.DefaultTTL(), TOS: stack.DefaultTOS}); err != nil { return 0, nil, err } diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go index 8f5572195..844959fa0 100644 --- a/pkg/tcpip/transport/tcp/accept.go +++ b/pkg/tcpip/transport/tcp/accept.go @@ -441,7 +441,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) { TSEcr: opts.TSVal, MSS: uint16(mss), } - e.sendSynTCP(&s.route, s.id, e.ttl, header.TCPFlagSyn|header.TCPFlagAck, cookie, s.sequenceNumber+1, ctx.rcvWnd, synOpts) + e.sendSynTCP(&s.route, s.id, e.ttl, e.sendTOS, header.TCPFlagSyn|header.TCPFlagAck, cookie, s.sequenceNumber+1, ctx.rcvWnd, synOpts) e.stack.Stats().TCP.ListenOverflowSynCookieSent.Increment() } diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go index cb8cfd619..5ea036bea 100644 --- a/pkg/tcpip/transport/tcp/connect.go +++ b/pkg/tcpip/transport/tcp/connect.go @@ -255,7 +255,7 @@ func (h *handshake) synSentState(s *segment) *tcpip.Error { if ttl == 0 { ttl = s.route.DefaultTTL() } - h.ep.sendSynTCP(&s.route, h.ep.ID, ttl, h.flags, h.iss, h.ackNum, h.rcvWnd, synOpts) + h.ep.sendSynTCP(&s.route, h.ep.ID, ttl, h.ep.sendTOS, h.flags, h.iss, h.ackNum, h.rcvWnd, synOpts) return nil } @@ -299,7 +299,7 @@ func (h *handshake) synRcvdState(s *segment) *tcpip.Error { SACKPermitted: h.ep.sackPermitted, MSS: h.ep.amss, } - h.ep.sendSynTCP(&s.route, h.ep.ID, h.ep.ttl, h.flags, h.iss, h.ackNum, h.rcvWnd, synOpts) + h.ep.sendSynTCP(&s.route, h.ep.ID, h.ep.ttl, h.ep.sendTOS, h.flags, h.iss, h.ackNum, h.rcvWnd, synOpts) return nil } @@ -468,7 +468,8 @@ func (h *handshake) execute() *tcpip.Error { synOpts.WS = -1 } } - h.ep.sendSynTCP(&h.ep.route, h.ep.ID, h.ep.ttl, h.flags, h.iss, h.ackNum, h.rcvWnd, synOpts) + h.ep.sendSynTCP(&h.ep.route, h.ep.ID, h.ep.ttl, h.ep.sendTOS, h.flags, h.iss, h.ackNum, h.rcvWnd, synOpts) + for h.state != handshakeCompleted { switch index, _ := s.Fetch(true); index { case wakerForResend: @@ -477,7 +478,7 @@ func (h *handshake) execute() *tcpip.Error { return tcpip.ErrTimeout } rt.Reset(timeOut) - h.ep.sendSynTCP(&h.ep.route, h.ep.ID, h.ep.ttl, h.flags, h.iss, h.ackNum, h.rcvWnd, synOpts) + h.ep.sendSynTCP(&h.ep.route, h.ep.ID, h.ep.ttl, h.ep.sendTOS, h.flags, h.iss, h.ackNum, h.rcvWnd, synOpts) case wakerForNotification: n := h.ep.fetchNotifications() @@ -587,17 +588,18 @@ func makeSynOptions(opts header.TCPSynOptions) []byte { return options[:offset] } -func (e *endpoint) sendSynTCP(r *stack.Route, id stack.TransportEndpointID, ttl uint8, flags byte, seq, ack seqnum.Value, rcvWnd seqnum.Size, opts header.TCPSynOptions) { +func (e *endpoint) sendSynTCP(r *stack.Route, id stack.TransportEndpointID, ttl, tos uint8, flags byte, seq, ack seqnum.Value, rcvWnd seqnum.Size, opts header.TCPSynOptions) *tcpip.Error { options := makeSynOptions(opts) // We ignore SYN send errors and let the callers re-attempt send. - if err := e.sendTCP(r, id, buffer.VectorisedView{}, ttl, flags, seq, ack, rcvWnd, options, nil); err != nil { + if err := e.sendTCP(r, id, buffer.VectorisedView{}, ttl, tos, flags, seq, ack, rcvWnd, options, nil); err != nil { e.stats.SendErrors.SynSendToNetworkFailed.Increment() } putOptions(options) + return nil } -func (e *endpoint) sendTCP(r *stack.Route, id stack.TransportEndpointID, data buffer.VectorisedView, ttl uint8, flags byte, seq, ack seqnum.Value, rcvWnd seqnum.Size, opts []byte, gso *stack.GSO) *tcpip.Error { - if err := sendTCP(r, id, data, ttl, flags, seq, ack, rcvWnd, opts, gso); err != nil { +func (e *endpoint) sendTCP(r *stack.Route, id stack.TransportEndpointID, data buffer.VectorisedView, ttl, tos uint8, flags byte, seq, ack seqnum.Value, rcvWnd seqnum.Size, opts []byte, gso *stack.GSO) *tcpip.Error { + if err := sendTCP(r, id, data, ttl, tos, flags, seq, ack, rcvWnd, opts, gso); err != nil { e.stats.SendErrors.SegmentSendToNetworkFailed.Increment() return err } @@ -607,7 +609,7 @@ func (e *endpoint) sendTCP(r *stack.Route, id stack.TransportEndpointID, data bu // sendTCP sends a TCP segment with the provided options via the provided // network endpoint and under the provided identity. -func sendTCP(r *stack.Route, id stack.TransportEndpointID, data buffer.VectorisedView, ttl uint8, flags byte, seq, ack seqnum.Value, rcvWnd seqnum.Size, opts []byte, gso *stack.GSO) *tcpip.Error { +func sendTCP(r *stack.Route, id stack.TransportEndpointID, data buffer.VectorisedView, ttl, tos uint8, flags byte, seq, ack seqnum.Value, rcvWnd seqnum.Size, opts []byte, gso *stack.GSO) *tcpip.Error { optLen := len(opts) // Allocate a buffer for the TCP header. hdr := buffer.NewPrependable(header.TCPMinimumSize + int(r.MaxHeaderLength()) + optLen) @@ -643,7 +645,10 @@ func sendTCP(r *stack.Route, id stack.TransportEndpointID, data buffer.Vectorise tcp.SetChecksum(^tcp.CalculateChecksum(xsum)) } - if err := r.WritePacket(gso, hdr, data, ProtocolNumber, ttl, ttl == 0 /* useDefaultTTL */); err != nil { + if ttl == 0 { + ttl = r.DefaultTTL() + } + if err := r.WritePacket(gso, hdr, data, stack.NetworkHeaderParams{Protocol: ProtocolNumber, TTL: ttl, TOS: tos}); err != nil { r.Stats().TCP.SegmentSendErrors.Increment() return err } @@ -700,7 +705,7 @@ func (e *endpoint) sendRaw(data buffer.VectorisedView, flags byte, seq, ack seqn sackBlocks = e.sack.Blocks[:e.sack.NumBlocks] } options := e.makeOptions(sackBlocks) - err := e.sendTCP(&e.route, e.ID, data, e.ttl, flags, seq, ack, rcvWnd, options, e.gso) + err := e.sendTCP(&e.route, e.ID, data, e.ttl, e.sendTOS, flags, seq, ack, rcvWnd, options, e.gso) putOptions(options) return err } diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index 090a8eb24..a1b784b49 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -494,6 +494,10 @@ type endpoint struct { // amss is the advertised MSS to the peer by this endpoint. amss uint16 + // sendTOS represents IPv4 TOS or IPv6 TrafficClass, + // applied while sending packets. Defaults to 0 as on Linux. + sendTOS uint8 + gso *stack.GSO // TODO(b/142022063): Add ability to save and restore per endpoint stats. @@ -1136,6 +1140,8 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOpt, v int) *tcpip.Error { // SetSockOpt sets a socket option. func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { + // Lower 2 bits represents ECN bits. RFC 3168, section 23.1 + const inetECNMask = 3 switch v := opt.(type) { case tcpip.DelayOption: if v == 0 { @@ -1296,6 +1302,23 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { // Linux returns ENOENT when an invalid congestion // control algorithm is specified. return tcpip.ErrNoSuchFile + + case tcpip.IPv4TOSOption: + e.mu.Lock() + // TODO(gvisor.dev/issue/995): ECN is not currently supported, + // ignore the bits for now. + e.sendTOS = uint8(v) & ^uint8(inetECNMask) + e.mu.Unlock() + return nil + + case tcpip.IPv6TrafficClassOption: + e.mu.Lock() + // TODO(gvisor.dev/issue/995): ECN is not currently supported, + // ignore the bits for now. + e.sendTOS = uint8(v) & ^uint8(inetECNMask) + e.mu.Unlock() + return nil + default: return nil } @@ -1495,6 +1518,18 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { e.mu.Unlock() return nil + case *tcpip.IPv4TOSOption: + e.mu.RLock() + *o = tcpip.IPv4TOSOption(e.sendTOS) + e.mu.RUnlock() + return nil + + case *tcpip.IPv6TrafficClassOption: + e.mu.RLock() + *o = tcpip.IPv6TrafficClassOption(e.sendTOS) + e.mu.RUnlock() + return nil + default: return tcpip.ErrUnknownProtocolOption } diff --git a/pkg/tcpip/transport/tcp/protocol.go b/pkg/tcpip/transport/tcp/protocol.go index d5d8ab96a..db40785d3 100644 --- a/pkg/tcpip/transport/tcp/protocol.go +++ b/pkg/tcpip/transport/tcp/protocol.go @@ -153,7 +153,7 @@ func replyWithReset(s *segment) { ack := s.sequenceNumber.Add(s.logicalLen()) - sendTCP(&s.route, s.id, buffer.VectorisedView{}, s.route.DefaultTTL(), header.TCPFlagRst|header.TCPFlagAck, seq, ack, 0, nil /* options */, nil /* gso */) + sendTCP(&s.route, s.id, buffer.VectorisedView{}, s.route.DefaultTTL(), stack.DefaultTOS, header.TCPFlagRst|header.TCPFlagAck, seq, ack, 0 /* rcvWnd */, nil /* options */, nil /* gso */) } // SetOption implements TransportProtocol.SetOption. diff --git a/pkg/tcpip/transport/tcp/tcp_state_autogen.go b/pkg/tcpip/transport/tcp/tcp_state_autogen.go index 53c6cdb24..57eb0ad58 100755 --- a/pkg/tcpip/transport/tcp/tcp_state_autogen.go +++ b/pkg/tcpip/transport/tcp/tcp_state_autogen.go @@ -141,6 +141,7 @@ func (x *endpoint) save(m state.Map) { m.Save("snd", &x.snd) m.Save("connectingAddress", &x.connectingAddress) m.Save("amss", &x.amss) + m.Save("sendTOS", &x.sendTOS) m.Save("gso", &x.gso) } @@ -189,6 +190,7 @@ func (x *endpoint) load(m state.Map) { m.LoadWait("snd", &x.snd) m.Load("connectingAddress", &x.connectingAddress) m.Load("amss", &x.amss) + m.Load("sendTOS", &x.sendTOS) m.Load("gso", &x.gso) m.LoadValue("lastError", new(string), func(y interface{}) { x.loadLastError(y.(string)) }) m.LoadValue("state", new(EndpointState), func(y interface{}) { x.loadState(y.(EndpointState)) }) diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index 2ad7978c2..6e87245b7 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -106,6 +106,10 @@ type endpoint struct { bindToDevice tcpip.NICID broadcast bool + // sendTOS represents IPv4 TOS or IPv6 TrafficClass, + // applied while sending packets. Defaults to 0 as on Linux. + sendTOS uint8 + // shutdownFlags represent the current shutdown state of the endpoint. shutdownFlags tcpip.ShutdownFlags @@ -429,7 +433,7 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c useDefaultTTL = false } - if err := sendUDP(route, buffer.View(v).ToVectorisedView(), e.ID.LocalPort, dstPort, ttl, useDefaultTTL); err != nil { + if err := sendUDP(route, buffer.View(v).ToVectorisedView(), e.ID.LocalPort, dstPort, ttl, useDefaultTTL, e.sendTOS); err != nil { return 0, nil, err } return int64(len(v)), nil, nil @@ -628,6 +632,18 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { e.mu.Unlock() return nil + + case tcpip.IPv4TOSOption: + e.mu.Lock() + e.sendTOS = uint8(v) + e.mu.Unlock() + return nil + + case tcpip.IPv6TrafficClassOption: + e.mu.Lock() + e.sendTOS = uint8(v) + e.mu.Unlock() + return nil } return nil } @@ -748,6 +764,18 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { } return nil + case *tcpip.IPv4TOSOption: + e.mu.RLock() + *o = tcpip.IPv4TOSOption(e.sendTOS) + e.mu.RUnlock() + return nil + + case *tcpip.IPv6TrafficClassOption: + e.mu.RLock() + *o = tcpip.IPv6TrafficClassOption(e.sendTOS) + e.mu.RUnlock() + return nil + default: return tcpip.ErrUnknownProtocolOption } @@ -755,7 +783,7 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { // sendUDP sends a UDP segment via the provided network endpoint and under the // provided identity. -func sendUDP(r *stack.Route, data buffer.VectorisedView, localPort, remotePort uint16, ttl uint8, useDefaultTTL bool) *tcpip.Error { +func sendUDP(r *stack.Route, data buffer.VectorisedView, localPort, remotePort uint16, ttl uint8, useDefaultTTL bool, tos uint8) *tcpip.Error { // Allocate a buffer for the UDP header. hdr := buffer.NewPrependable(header.UDPMinimumSize + int(r.MaxHeaderLength())) @@ -778,7 +806,10 @@ func sendUDP(r *stack.Route, data buffer.VectorisedView, localPort, remotePort u udp.SetChecksum(^udp.CalculateChecksum(xsum)) } - if err := r.WritePacket(nil /* gso */, hdr, data, ProtocolNumber, ttl, useDefaultTTL); err != nil { + if useDefaultTTL { + ttl = r.DefaultTTL() + } + if err := r.WritePacket(nil /* gso */, hdr, data, stack.NetworkHeaderParams{Protocol: ProtocolNumber, TTL: ttl, TOS: tos}); err != nil { r.Stats().UDP.PacketSendErrors.Increment() return err } diff --git a/pkg/tcpip/transport/udp/protocol.go b/pkg/tcpip/transport/udp/protocol.go index 0b24bbc85..de026880f 100644 --- a/pkg/tcpip/transport/udp/protocol.go +++ b/pkg/tcpip/transport/udp/protocol.go @@ -130,7 +130,7 @@ func (p *protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.Trans pkt.SetType(header.ICMPv4DstUnreachable) pkt.SetCode(header.ICMPv4PortUnreachable) pkt.SetChecksum(header.ICMPv4Checksum(pkt, payload)) - r.WritePacket(nil /* gso */, hdr, payload, header.ICMPv4ProtocolNumber, 0, true /* useDefaultTTL */) + r.WritePacket(nil /* gso */, hdr, payload, stack.NetworkHeaderParams{Protocol: header.ICMPv4ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS}) case header.IPv6AddressSize: if !r.Stack().AllowICMPMessage() { @@ -164,7 +164,7 @@ func (p *protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.Trans pkt.SetType(header.ICMPv6DstUnreachable) pkt.SetCode(header.ICMPv6PortUnreachable) pkt.SetChecksum(header.ICMPv6Checksum(pkt, r.LocalAddress, r.RemoteAddress, payload)) - r.WritePacket(nil /* gso */, hdr, payload, header.ICMPv6ProtocolNumber, 0, true /* useDefaultTTL */) + r.WritePacket(nil /* gso */, hdr, payload, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS}) } return true } diff --git a/pkg/tcpip/transport/udp/udp_state_autogen.go b/pkg/tcpip/transport/udp/udp_state_autogen.go index ca3fccc4b..bd811ed5c 100755 --- a/pkg/tcpip/transport/udp/udp_state_autogen.go +++ b/pkg/tcpip/transport/udp/udp_state_autogen.go @@ -47,6 +47,7 @@ func (x *endpoint) save(m state.Map) { m.Save("reusePort", &x.reusePort) m.Save("bindToDevice", &x.bindToDevice) m.Save("broadcast", &x.broadcast) + m.Save("sendTOS", &x.sendTOS) m.Save("shutdownFlags", &x.shutdownFlags) m.Save("multicastMemberships", &x.multicastMemberships) m.Save("effectiveNetProtos", &x.effectiveNetProtos) @@ -71,6 +72,7 @@ func (x *endpoint) load(m state.Map) { m.Load("reusePort", &x.reusePort) m.Load("bindToDevice", &x.bindToDevice) m.Load("broadcast", &x.broadcast) + m.Load("sendTOS", &x.sendTOS) m.Load("shutdownFlags", &x.shutdownFlags) m.Load("multicastMemberships", &x.multicastMemberships) m.Load("effectiveNetProtos", &x.effectiveNetProtos) |