summaryrefslogtreecommitdiffhomepage
path: root/pkg/tcpip
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/tcpip')
-rw-r--r--pkg/tcpip/transport/internal/network/endpoint.go26
-rw-r--r--pkg/tcpip/transport/udp/endpoint.go25
-rw-r--r--pkg/tcpip/transport/udp/udp_test.go46
3 files changed, 68 insertions, 29 deletions
diff --git a/pkg/tcpip/transport/internal/network/endpoint.go b/pkg/tcpip/transport/internal/network/endpoint.go
index 0dce60d89..c5b575e1c 100644
--- a/pkg/tcpip/transport/internal/network/endpoint.go
+++ b/pkg/tcpip/transport/internal/network/endpoint.go
@@ -60,10 +60,8 @@ type Endpoint struct {
multicastAddr tcpip.Address
// TODO(https://gvisor.dev/issue/6389): Use different fields for IPv4/IPv6.
multicastNICID tcpip.NICID
- // sendTOS represents IPv4 TOS or IPv6 TrafficClass,
- // applied while sending packets. Defaults to 0 as on Linux.
- // TODO(https://gvisor.dev/issue/6389): Use different fields for IPv4/IPv6.
- sendTOS uint8
+ ipv4TOS uint8
+ ipv6TClass uint8
}
// +stateify savable
@@ -267,11 +265,21 @@ func (e *Endpoint) AcquireContextForWrite(opts tcpip.WriteOptions) (WriteContext
return WriteContext{}, &tcpip.ErrBroadcastDisabled{}
}
+ var tos uint8
+ switch netProto := route.NetProto(); netProto {
+ case header.IPv4ProtocolNumber:
+ tos = e.ipv4TOS
+ case header.IPv6ProtocolNumber:
+ tos = e.ipv6TClass
+ default:
+ panic(fmt.Sprintf("invalid protocol number = %d", netProto))
+ }
+
return WriteContext{
transProto: e.transProto,
route: route,
ttl: calculateTTL(route, e.ttl, e.multicastTTL),
- tos: e.sendTOS,
+ tos: tos,
owner: e.owner,
}, nil
}
@@ -533,12 +541,12 @@ func (e *Endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) tcpip.Error {
case tcpip.IPv4TOSOption:
e.mu.Lock()
- e.sendTOS = uint8(v)
+ e.ipv4TOS = uint8(v)
e.mu.Unlock()
case tcpip.IPv6TrafficClassOption:
e.mu.Lock()
- e.sendTOS = uint8(v)
+ e.ipv6TClass = uint8(v)
e.mu.Unlock()
}
@@ -566,13 +574,13 @@ func (e *Endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, tcpip.Error) {
case tcpip.IPv4TOSOption:
e.mu.RLock()
- v := int(e.sendTOS)
+ v := int(e.ipv4TOS)
e.mu.RUnlock()
return v, nil
case tcpip.IPv6TrafficClassOption:
e.mu.RLock()
- v := int(e.sendTOS)
+ v := int(e.ipv6TClass)
e.mu.RUnlock()
return v, nil
diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go
index 4b6bdc3be..f171a16f8 100644
--- a/pkg/tcpip/transport/udp/endpoint.go
+++ b/pkg/tcpip/transport/udp/endpoint.go
@@ -33,6 +33,7 @@ import (
// +stateify savable
type udpPacket struct {
udpPacketEntry
+ netProto tcpip.NetworkProtocolNumber
senderAddress tcpip.FullAddress
destinationAddress tcpip.FullAddress
packetInfo tcpip.IPPacketInfo
@@ -235,14 +236,21 @@ func (e *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResult
HasTimestamp: true,
Timestamp: p.receivedAt.UnixNano(),
}
- if e.ops.GetReceiveTOS() {
- cm.HasTOS = true
- cm.TOS = p.tos
- }
- if e.ops.GetReceiveTClass() {
- cm.HasTClass = true
- // Although TClass is an 8-bit value it's read in the CMsg as a uint32.
- cm.TClass = uint32(p.tos)
+
+ switch p.netProto {
+ case header.IPv4ProtocolNumber:
+ if e.ops.GetReceiveTOS() {
+ cm.HasTOS = true
+ cm.TOS = p.tos
+ }
+ case header.IPv6ProtocolNumber:
+ if e.ops.GetReceiveTClass() {
+ cm.HasTClass = true
+ // Although TClass is an 8-bit value it's read in the CMsg as a uint32.
+ cm.TClass = uint32(p.tos)
+ }
+ default:
+ panic(fmt.Sprintf("unrecognized network protocol = %d", p.netProto))
}
if e.ops.GetReceivePacketInfo() {
cm.HasIPPacketInfo = true
@@ -888,6 +896,7 @@ func (e *endpoint) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketB
// Push new packet into receive list and increment the buffer size.
packet := &udpPacket{
+ netProto: pkt.NetworkProtocolNumber,
senderAddress: tcpip.FullAddress{
NIC: pkt.NICID,
Addr: id.RemoteAddress,
diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go
index 4008cacf2..554ce1de4 100644
--- a/pkg/tcpip/transport/udp/udp_test.go
+++ b/pkg/tcpip/transport/udp/udp_test.go
@@ -290,6 +290,7 @@ type testContext struct {
t *testing.T
linkEP *channel.Endpoint
s *stack.Stack
+ nicID tcpip.NICID
ep tcpip.Endpoint
wq waiter.Queue
@@ -301,6 +302,8 @@ func newDualTestContext(t *testing.T, mtu uint32) *testContext {
}
func newDualTestContextWithHandleLocal(t *testing.T, mtu uint32, handleLocal bool) *testContext {
+ const nicID = 1
+
t.Helper()
options := stack.Options{
@@ -316,32 +319,33 @@ func newDualTestContextWithHandleLocal(t *testing.T, mtu uint32, handleLocal boo
if testing.Verbose() {
wep = sniffer.New(ep)
}
- if err := s.CreateNIC(1, wep); err != nil {
- t.Fatalf("CreateNIC failed: %s", err)
+ if err := s.CreateNIC(nicID, wep); err != nil {
+ t.Fatalf("CreateNIC(%d, _): %s", nicID, err)
}
- if err := s.AddAddress(1, ipv4.ProtocolNumber, stackAddr); err != nil {
- t.Fatalf("AddAddress failed: %s", err)
+ if err := s.AddAddress(nicID, ipv4.ProtocolNumber, stackAddr); err != nil {
+ t.Fatalf("AddAddress(%d, %d, %s): %s", nicID, ipv4.ProtocolNumber, stackAddr, err)
}
- if err := s.AddAddress(1, ipv6.ProtocolNumber, stackV6Addr); err != nil {
- t.Fatalf("AddAddress failed: %s", err)
+ if err := s.AddAddress(nicID, ipv6.ProtocolNumber, stackV6Addr); err != nil {
+ t.Fatalf("AddAddress((%d, %d, %s): %s", nicID, ipv6.ProtocolNumber, stackV6Addr, err)
}
s.SetRouteTable([]tcpip.Route{
{
Destination: header.IPv4EmptySubnet,
- NIC: 1,
+ NIC: nicID,
},
{
Destination: header.IPv6EmptySubnet,
- NIC: 1,
+ NIC: nicID,
},
})
return &testContext{
t: t,
s: s,
+ nicID: nicID,
linkEP: ep,
}
}
@@ -1644,8 +1648,10 @@ func TestSetTTL(t *testing.T) {
}
}
+var v4PacketFlows = [...]testFlow{unicastV4, multicastV4, broadcast, unicastV4in6, multicastV4in6, broadcastIn6}
+
func TestSetTOS(t *testing.T) {
- for _, flow := range []testFlow{unicastV4, multicastV4, broadcast} {
+ for _, flow := range v4PacketFlows {
t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) {
c := newDualTestContext(t, defaultMTU)
defer c.cleanup()
@@ -1680,8 +1686,10 @@ func TestSetTOS(t *testing.T) {
}
}
+var v6PacketFlows = [...]testFlow{unicastV6, unicastV6Only, multicastV6}
+
func TestSetTClass(t *testing.T) {
- for _, flow := range []testFlow{unicastV4in6, unicastV6, unicastV6Only, multicastV4in6, multicastV6, broadcastIn6} {
+ for _, flow := range v6PacketFlows {
t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) {
c := newDualTestContext(t, defaultMTU)
defer c.cleanup()
@@ -1725,8 +1733,14 @@ func TestReceiveTosTClass(t *testing.T) {
name string
tests []testFlow
}{
- {RcvTOSOpt, []testFlow{unicastV4, broadcast}},
- {RcvTClassOpt, []testFlow{unicastV4in6, unicastV6, unicastV6Only, broadcastIn6}},
+ {
+ name: RcvTOSOpt,
+ tests: v4PacketFlows[:],
+ },
+ {
+ name: RcvTClassOpt,
+ tests: v6PacketFlows[:],
+ },
}
for _, testCase := range testCases {
for _, flow := range testCase.tests {
@@ -1737,6 +1751,14 @@ func TestReceiveTosTClass(t *testing.T) {
c.createEndpointForFlow(flow)
name := testCase.name
+ if flow.isMulticast() {
+ netProto := flow.netProto()
+ addr := flow.getMcastAddr()
+ if err := c.s.JoinGroup(netProto, c.nicID, addr); err != nil {
+ c.t.Fatalf("JoinGroup(%d, %d, %s): %s", netProto, c.nicID, addr, err)
+ }
+ }
+
var optionGetter func() bool
var optionSetter func(bool)
switch name {