diff options
Diffstat (limited to 'pkg/tcpip/transport/udp')
-rw-r--r-- | pkg/tcpip/transport/udp/endpoint.go | 122 | ||||
-rw-r--r-- | pkg/tcpip/transport/udp/endpoint_state.go | 2 | ||||
-rw-r--r-- | pkg/tcpip/transport/udp/udp_test.go | 55 |
3 files changed, 38 insertions, 141 deletions
diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index 9d33a694b..a9c74148b 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -101,12 +101,10 @@ type endpoint struct { state EndpointState route *stack.Route `state:"manual"` dstPort uint16 - v6only bool ttl uint8 multicastTTL uint8 multicastAddr tcpip.Address multicastNICID tcpip.NICID - multicastLoop bool portFlags ports.Flags bindToDevice tcpip.NICID @@ -122,17 +120,6 @@ type endpoint struct { // applied while sending packets. Defaults to 0 as on Linux. sendTOS uint8 - // receiveTOS determines if the incoming IPv4 TOS header field is passed - // as ancillary data to ControlMessages on Read. - receiveTOS bool - - // receiveTClass determines if the incoming IPv6 TClass header field is - // passed as ancillary data to ControlMessages on Read. - receiveTClass bool - - // receiveIPPacketInfo determines if the packet info is returned by Read. - receiveIPPacketInfo bool - // shutdownFlags represent the current shutdown state of the endpoint. shutdownFlags tcpip.ShutdownFlags @@ -188,7 +175,6 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue // // Linux defaults to TTL=1. multicastTTL: 1, - multicastLoop: true, rcvBufSizeMax: 32 * 1024, sndBufSizeMax: 32 * 1024, multicastMemberships: make(map[multicastMembership]struct{}), @@ -196,6 +182,7 @@ func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQue uniqueID: s.UniqueID(), } e.ops.InitHandler(e) + e.ops.SetMulticastLoop(true) // Override with stack defaults. var ss stack.SendBufferSizeOption @@ -307,21 +294,16 @@ func (e *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMess HasTimestamp: true, Timestamp: p.timestamp, } - e.mu.RLock() - receiveTOS := e.receiveTOS - receiveTClass := e.receiveTClass - receiveIPPacketInfo := e.receiveIPPacketInfo - e.mu.RUnlock() - if receiveTOS { + if e.ops.GetReceiveTOS() { cm.HasTOS = true cm.TOS = p.tos } - if receiveTClass { + if e.ops.GetReceiveTClass() { cm.HasTClass = true // Although TClass is an 8-bit value it's read in the CMsg as a uint32. cm.TClass = uint32(p.tos) } - if receiveIPPacketInfo { + if e.ops.GetReceivePacketInfo() { cm.HasIPPacketInfo = true cm.PacketInfo = p.packetInfo } @@ -388,7 +370,7 @@ func (e *endpoint) connectRoute(nicID tcpip.NICID, addr tcpip.FullAddress, netPr } // Find a route to the desired destination. - r, err := e.stack.FindRoute(nicID, localAddr, addr.Addr, netProto, e.multicastLoop) + r, err := e.stack.FindRoute(nicID, localAddr, addr.Addr, netProto, e.ops.GetMulticastLoop()) if err != nil { return nil, 0, err } @@ -595,48 +577,6 @@ func (e *endpoint) OnReusePortSet(v bool) { // SetSockOptBool implements tcpip.Endpoint.SetSockOptBool. func (e *endpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error { - switch opt { - case tcpip.MulticastLoopOption: - e.mu.Lock() - e.multicastLoop = v - e.mu.Unlock() - - case tcpip.ReceiveTOSOption: - e.mu.Lock() - e.receiveTOS = v - e.mu.Unlock() - - case tcpip.ReceiveTClassOption: - // We only support this option on v6 endpoints. - if e.NetProto != header.IPv6ProtocolNumber { - return tcpip.ErrNotSupported - } - - e.mu.Lock() - e.receiveTClass = v - e.mu.Unlock() - - case tcpip.ReceiveIPPacketInfoOption: - e.mu.Lock() - e.receiveIPPacketInfo = v - e.mu.Unlock() - - case tcpip.V6OnlyOption: - // We only recognize this option on v6 endpoints. - if e.NetProto != header.IPv6ProtocolNumber { - return tcpip.ErrInvalidEndpointState - } - - e.mu.Lock() - defer e.mu.Unlock() - - // We only allow this to be set when we're in the initial state. - if e.state != StateInitial { - return tcpip.ErrInvalidEndpointState - } - - e.v6only = v - } return nil } @@ -851,51 +791,7 @@ func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error { // GetSockOptBool implements tcpip.Endpoint.GetSockOptBool. func (e *endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) { - switch opt { - case tcpip.MulticastLoopOption: - e.mu.RLock() - v := e.multicastLoop - e.mu.RUnlock() - return v, nil - - case tcpip.ReceiveTOSOption: - e.mu.RLock() - v := e.receiveTOS - e.mu.RUnlock() - return v, nil - - case tcpip.ReceiveTClassOption: - // We only support this option on v6 endpoints. - if e.NetProto != header.IPv6ProtocolNumber { - return false, tcpip.ErrNotSupported - } - - e.mu.RLock() - v := e.receiveTClass - e.mu.RUnlock() - return v, nil - - case tcpip.ReceiveIPPacketInfoOption: - e.mu.RLock() - v := e.receiveIPPacketInfo - e.mu.RUnlock() - return v, nil - - case tcpip.V6OnlyOption: - // We only recognize this option on v6 endpoints. - if e.NetProto != header.IPv6ProtocolNumber { - return false, tcpip.ErrUnknownProtocolOption - } - - e.mu.RLock() - v := e.v6only - e.mu.RUnlock() - - return v, nil - - default: - return false, tcpip.ErrUnknownProtocolOption - } + return false, tcpip.ErrUnknownProtocolOption } // GetSockOptInt implements tcpip.Endpoint.GetSockOptInt. @@ -1036,7 +932,7 @@ func sendUDP(r *stack.Route, data buffer.VectorisedView, localPort, remotePort u // checkV4MappedLocked determines the effective network protocol and converts // addr to its canonical form. func (e *endpoint) checkV4MappedLocked(addr tcpip.FullAddress) (tcpip.FullAddress, tcpip.NetworkProtocolNumber, *tcpip.Error) { - unwrapped, netProto, err := e.TransportEndpointInfo.AddrNetProtoLocked(addr, e.v6only) + unwrapped, netProto, err := e.TransportEndpointInfo.AddrNetProtoLocked(addr, e.ops.GetV6Only()) if err != nil { return tcpip.FullAddress{}, 0, err } @@ -1147,7 +1043,7 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { // packets on a different network protocol, so we register both even if // v6only is set to false and this is an ipv6 endpoint. netProtos := []tcpip.NetworkProtocolNumber{netProto} - if netProto == header.IPv6ProtocolNumber && !e.v6only { + if netProto == header.IPv6ProtocolNumber && !e.ops.GetV6Only() { netProtos = []tcpip.NetworkProtocolNumber{ header.IPv4ProtocolNumber, header.IPv6ProtocolNumber, @@ -1259,7 +1155,7 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) *tcpip.Error { // wildcard (empty) address, and this is an IPv6 endpoint with v6only // set to false. netProtos := []tcpip.NetworkProtocolNumber{netProto} - if netProto == header.IPv6ProtocolNumber && !e.v6only && addr.Addr == "" { + if netProto == header.IPv6ProtocolNumber && !e.ops.GetV6Only() && addr.Addr == "" { netProtos = []tcpip.NetworkProtocolNumber{ header.IPv6ProtocolNumber, header.IPv4ProtocolNumber, diff --git a/pkg/tcpip/transport/udp/endpoint_state.go b/pkg/tcpip/transport/udp/endpoint_state.go index 99f3fc37f..9d06035ea 100644 --- a/pkg/tcpip/transport/udp/endpoint_state.go +++ b/pkg/tcpip/transport/udp/endpoint_state.go @@ -114,7 +114,7 @@ func (e *endpoint) Resume(s *stack.Stack) { var err *tcpip.Error if e.state == StateConnected { - e.route, err = e.stack.FindRoute(e.RegisterNICID, e.ID.LocalAddress, e.ID.RemoteAddress, netProto, e.multicastLoop) + e.route, err = e.stack.FindRoute(e.RegisterNICID, e.ID.LocalAddress, e.ID.RemoteAddress, netProto, e.ops.GetMulticastLoop()) if err != nil { panic(err) } diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go index 1233bab14..e384f52dd 100644 --- a/pkg/tcpip/transport/udp/udp_test.go +++ b/pkg/tcpip/transport/udp/udp_test.go @@ -363,9 +363,7 @@ func (c *testContext) createEndpointForFlow(flow testFlow) { c.createEndpoint(flow.sockProto()) if flow.isV6Only() { - if err := c.ep.SetSockOptBool(tcpip.V6OnlyOption, true); err != nil { - c.t.Fatalf("SetSockOptBool failed: %s", err) - } + c.ep.SocketOptions().SetV6Only(true) } else if flow.isBroadcast() { c.ep.SocketOptions().SetBroadcast(true) } @@ -1414,9 +1412,7 @@ func TestReadIPPacketInfo(t *testing.T) { } } - if err := c.ep.SetSockOptBool(tcpip.ReceiveIPPacketInfoOption, true); err != nil { - t.Fatalf("c.ep.SetSockOptBool(tcpip.ReceiveIPPacketInfoOption, true): %s", err) - } + c.ep.SocketOptions().SetReceivePacketInfo(true) testRead(c, test.flow, checker.ReceiveIPPacketInfo(tcpip.IPPacketInfo{ NIC: 1, @@ -1629,13 +1625,15 @@ func TestSetTClass(t *testing.T) { } func TestReceiveTosTClass(t *testing.T) { + const RcvTOSOpt = "ReceiveTosOption" + const RcvTClassOpt = "ReceiveTClassOption" + testCases := []struct { - name string - getReceiveOption tcpip.SockOptBool - tests []testFlow + name string + tests []testFlow }{ - {"ReceiveTosOption", tcpip.ReceiveTOSOption, []testFlow{unicastV4, broadcast}}, - {"ReceiveTClassOption", tcpip.ReceiveTClassOption, []testFlow{unicastV4in6, unicastV6, unicastV6Only, broadcastIn6}}, + {RcvTOSOpt, []testFlow{unicastV4, broadcast}}, + {RcvTClassOpt, []testFlow{unicastV4in6, unicastV6, unicastV6Only, broadcastIn6}}, } for _, testCase := range testCases { for _, flow := range testCase.tests { @@ -1644,29 +1642,32 @@ func TestReceiveTosTClass(t *testing.T) { defer c.cleanup() c.createEndpointForFlow(flow) - option := testCase.getReceiveOption name := testCase.name - // Verify that setting and reading the option works. - v, err := c.ep.GetSockOptBool(option) - if err != nil { - c.t.Errorf("GetSockOptBool(%s) failed: %s", name, err) + var optionGetter func() bool + var optionSetter func(bool) + switch name { + case RcvTOSOpt: + optionGetter = c.ep.SocketOptions().GetReceiveTOS + optionSetter = c.ep.SocketOptions().SetReceiveTOS + case RcvTClassOpt: + optionGetter = c.ep.SocketOptions().GetReceiveTClass + optionSetter = c.ep.SocketOptions().SetReceiveTClass + default: + t.Fatalf("unkown test variant: %s", name) } + + // Verify that setting and reading the option works. + v := optionGetter() // Test for expected default value. if v != false { c.t.Errorf("got GetSockOptBool(%s) = %t, want = %t", name, v, false) } want := true - if err := c.ep.SetSockOptBool(option, want); err != nil { - c.t.Fatalf("SetSockOptBool(%s, %t) failed: %s", name, want, err) - } - - got, err := c.ep.GetSockOptBool(option) - if err != nil { - c.t.Errorf("GetSockOptBool(%s) failed: %s", name, err) - } + optionSetter(want) + got := optionGetter() if got != want { c.t.Errorf("got GetSockOptBool(%s) = %t, want = %t", name, got, want) } @@ -1676,10 +1677,10 @@ func TestReceiveTosTClass(t *testing.T) { if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil { c.t.Fatalf("Bind failed: %s", err) } - switch option { - case tcpip.ReceiveTClassOption: + switch name { + case RcvTClassOpt: testRead(c, flow, checker.ReceiveTClass(testTOS)) - case tcpip.ReceiveTOSOption: + case RcvTOSOpt: testRead(c, flow, checker.ReceiveTOS(testTOS)) default: t.Fatalf("unknown test variant: %s", name) |