diff options
Diffstat (limited to 'pkg/tcpip')
-rw-r--r-- | pkg/tcpip/transport/internal/network/endpoint.go | 26 | ||||
-rw-r--r-- | pkg/tcpip/transport/internal/network/network_state_autogen.go | 9 | ||||
-rw-r--r-- | pkg/tcpip/transport/udp/endpoint.go | 25 | ||||
-rw-r--r-- | pkg/tcpip/transport/udp/udp_state_autogen.go | 27 |
4 files changed, 55 insertions, 32 deletions
diff --git a/pkg/tcpip/transport/internal/network/endpoint.go b/pkg/tcpip/transport/internal/network/endpoint.go index 0dce60d89..c5b575e1c 100644 --- a/pkg/tcpip/transport/internal/network/endpoint.go +++ b/pkg/tcpip/transport/internal/network/endpoint.go @@ -60,10 +60,8 @@ type Endpoint struct { multicastAddr tcpip.Address // TODO(https://gvisor.dev/issue/6389): Use different fields for IPv4/IPv6. multicastNICID tcpip.NICID - // sendTOS represents IPv4 TOS or IPv6 TrafficClass, - // applied while sending packets. Defaults to 0 as on Linux. - // TODO(https://gvisor.dev/issue/6389): Use different fields for IPv4/IPv6. - sendTOS uint8 + ipv4TOS uint8 + ipv6TClass uint8 } // +stateify savable @@ -267,11 +265,21 @@ func (e *Endpoint) AcquireContextForWrite(opts tcpip.WriteOptions) (WriteContext return WriteContext{}, &tcpip.ErrBroadcastDisabled{} } + var tos uint8 + switch netProto := route.NetProto(); netProto { + case header.IPv4ProtocolNumber: + tos = e.ipv4TOS + case header.IPv6ProtocolNumber: + tos = e.ipv6TClass + default: + panic(fmt.Sprintf("invalid protocol number = %d", netProto)) + } + return WriteContext{ transProto: e.transProto, route: route, ttl: calculateTTL(route, e.ttl, e.multicastTTL), - tos: e.sendTOS, + tos: tos, owner: e.owner, }, nil } @@ -533,12 +541,12 @@ func (e *Endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) tcpip.Error { case tcpip.IPv4TOSOption: e.mu.Lock() - e.sendTOS = uint8(v) + e.ipv4TOS = uint8(v) e.mu.Unlock() case tcpip.IPv6TrafficClassOption: e.mu.Lock() - e.sendTOS = uint8(v) + e.ipv6TClass = uint8(v) e.mu.Unlock() } @@ -566,13 +574,13 @@ func (e *Endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, tcpip.Error) { case tcpip.IPv4TOSOption: e.mu.RLock() - v := int(e.sendTOS) + v := int(e.ipv4TOS) e.mu.RUnlock() return v, nil case tcpip.IPv6TrafficClassOption: e.mu.RLock() - v := int(e.sendTOS) + v := int(e.ipv6TClass) e.mu.RUnlock() return v, nil diff --git a/pkg/tcpip/transport/internal/network/network_state_autogen.go b/pkg/tcpip/transport/internal/network/network_state_autogen.go index 8f1cf9c0d..0ce695bb8 100644 --- a/pkg/tcpip/transport/internal/network/network_state_autogen.go +++ b/pkg/tcpip/transport/internal/network/network_state_autogen.go @@ -25,7 +25,8 @@ func (e *Endpoint) StateFields() []string { "multicastTTL", "multicastAddr", "multicastNICID", - "sendTOS", + "ipv4TOS", + "ipv6TClass", } } @@ -47,7 +48,8 @@ func (e *Endpoint) StateSave(stateSinkObject state.Sink) { stateSinkObject.Save(10, &e.multicastTTL) stateSinkObject.Save(11, &e.multicastAddr) stateSinkObject.Save(12, &e.multicastNICID) - stateSinkObject.Save(13, &e.sendTOS) + stateSinkObject.Save(13, &e.ipv4TOS) + stateSinkObject.Save(14, &e.ipv6TClass) } func (e *Endpoint) afterLoad() {} @@ -67,7 +69,8 @@ func (e *Endpoint) StateLoad(stateSourceObject state.Source) { stateSourceObject.Load(10, &e.multicastTTL) stateSourceObject.Load(11, &e.multicastAddr) stateSourceObject.Load(12, &e.multicastNICID) - stateSourceObject.Load(13, &e.sendTOS) + stateSourceObject.Load(13, &e.ipv4TOS) + stateSourceObject.Load(14, &e.ipv6TClass) } func (m *multicastMembership) StateTypeName() string { diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index 4b6bdc3be..f171a16f8 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -33,6 +33,7 @@ import ( // +stateify savable type udpPacket struct { udpPacketEntry + netProto tcpip.NetworkProtocolNumber senderAddress tcpip.FullAddress destinationAddress tcpip.FullAddress packetInfo tcpip.IPPacketInfo @@ -235,14 +236,21 @@ func (e *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResult HasTimestamp: true, Timestamp: p.receivedAt.UnixNano(), } - if e.ops.GetReceiveTOS() { - cm.HasTOS = true - cm.TOS = p.tos - } - if e.ops.GetReceiveTClass() { - cm.HasTClass = true - // Although TClass is an 8-bit value it's read in the CMsg as a uint32. - cm.TClass = uint32(p.tos) + + switch p.netProto { + case header.IPv4ProtocolNumber: + if e.ops.GetReceiveTOS() { + cm.HasTOS = true + cm.TOS = p.tos + } + case header.IPv6ProtocolNumber: + if e.ops.GetReceiveTClass() { + cm.HasTClass = true + // Although TClass is an 8-bit value it's read in the CMsg as a uint32. + cm.TClass = uint32(p.tos) + } + default: + panic(fmt.Sprintf("unrecognized network protocol = %d", p.netProto)) } if e.ops.GetReceivePacketInfo() { cm.HasIPPacketInfo = true @@ -888,6 +896,7 @@ func (e *endpoint) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketB // Push new packet into receive list and increment the buffer size. packet := &udpPacket{ + netProto: pkt.NetworkProtocolNumber, senderAddress: tcpip.FullAddress{ NIC: pkt.NICID, Addr: id.RemoteAddress, diff --git a/pkg/tcpip/transport/udp/udp_state_autogen.go b/pkg/tcpip/transport/udp/udp_state_autogen.go index 58584c88c..e25607e3f 100644 --- a/pkg/tcpip/transport/udp/udp_state_autogen.go +++ b/pkg/tcpip/transport/udp/udp_state_autogen.go @@ -14,6 +14,7 @@ func (p *udpPacket) StateTypeName() string { func (p *udpPacket) StateFields() []string { return []string{ "udpPacketEntry", + "netProto", "senderAddress", "destinationAddress", "packetInfo", @@ -30,15 +31,16 @@ func (p *udpPacket) StateSave(stateSinkObject state.Sink) { p.beforeSave() var dataValue buffer.VectorisedView dataValue = p.saveData() - stateSinkObject.SaveValue(4, dataValue) + stateSinkObject.SaveValue(5, dataValue) var receivedAtValue int64 receivedAtValue = p.saveReceivedAt() - stateSinkObject.SaveValue(5, receivedAtValue) + stateSinkObject.SaveValue(6, receivedAtValue) stateSinkObject.Save(0, &p.udpPacketEntry) - stateSinkObject.Save(1, &p.senderAddress) - stateSinkObject.Save(2, &p.destinationAddress) - stateSinkObject.Save(3, &p.packetInfo) - stateSinkObject.Save(6, &p.tos) + stateSinkObject.Save(1, &p.netProto) + stateSinkObject.Save(2, &p.senderAddress) + stateSinkObject.Save(3, &p.destinationAddress) + stateSinkObject.Save(4, &p.packetInfo) + stateSinkObject.Save(7, &p.tos) } func (p *udpPacket) afterLoad() {} @@ -46,12 +48,13 @@ func (p *udpPacket) afterLoad() {} // +checklocksignore func (p *udpPacket) StateLoad(stateSourceObject state.Source) { stateSourceObject.Load(0, &p.udpPacketEntry) - stateSourceObject.Load(1, &p.senderAddress) - stateSourceObject.Load(2, &p.destinationAddress) - stateSourceObject.Load(3, &p.packetInfo) - stateSourceObject.Load(6, &p.tos) - stateSourceObject.LoadValue(4, new(buffer.VectorisedView), func(y interface{}) { p.loadData(y.(buffer.VectorisedView)) }) - stateSourceObject.LoadValue(5, new(int64), func(y interface{}) { p.loadReceivedAt(y.(int64)) }) + stateSourceObject.Load(1, &p.netProto) + stateSourceObject.Load(2, &p.senderAddress) + stateSourceObject.Load(3, &p.destinationAddress) + stateSourceObject.Load(4, &p.packetInfo) + stateSourceObject.Load(7, &p.tos) + stateSourceObject.LoadValue(5, new(buffer.VectorisedView), func(y interface{}) { p.loadData(y.(buffer.VectorisedView)) }) + stateSourceObject.LoadValue(6, new(int64), func(y interface{}) { p.loadReceivedAt(y.(int64)) }) } func (e *endpoint) StateTypeName() string { |