From ec0aa657edfd98a1e8dfbbf017ee6cf8c7f1a40e Mon Sep 17 00:00:00 2001 From: Nick Brown Date: Wed, 24 Mar 2021 09:36:50 -0700 Subject: Unexpose immutable fields in stack.Route This change sets the inner `routeInfo` struct to be a named private member and replaces direct access with access through getters. Note that direct access to the fields of `routeInfo` is still possible through the `RouteInfo` struct. Fixes #4902 PiperOrigin-RevId: 364822872 --- pkg/tcpip/transport/icmp/endpoint.go | 8 ++++---- pkg/tcpip/transport/icmp/endpoint_state.go | 2 +- pkg/tcpip/transport/raw/endpoint.go | 2 +- pkg/tcpip/transport/tcp/connect.go | 2 +- pkg/tcpip/transport/tcp/endpoint.go | 6 +++--- pkg/tcpip/transport/udp/endpoint.go | 14 +++++++------- 6 files changed, 17 insertions(+), 17 deletions(-) (limited to 'pkg/tcpip/transport') diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go index 06c63e74a..1dce35c63 100644 --- a/pkg/tcpip/transport/icmp/endpoint.go +++ b/pkg/tcpip/transport/icmp/endpoint.go @@ -467,8 +467,8 @@ func send6(r *stack.Route, ident uint16, data buffer.View, ttl uint8) tcpip.Erro dataRange := pkt.Data().AsRange() icmpv6.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{ Header: icmpv6, - Src: r.LocalAddress, - Dst: r.RemoteAddress, + Src: r.LocalAddress(), + Dst: r.RemoteAddress(), PayloadCsum: dataRange.Checksum(), PayloadLen: dataRange.Size(), })) @@ -536,9 +536,9 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) tcpip.Error { } id := stack.TransportEndpointID{ - LocalAddress: r.LocalAddress, + LocalAddress: r.LocalAddress(), LocalPort: localPort, - RemoteAddress: r.RemoteAddress, + RemoteAddress: r.RemoteAddress(), } // Even if we're connected, this endpoint can still be used to send diff --git a/pkg/tcpip/transport/icmp/endpoint_state.go b/pkg/tcpip/transport/icmp/endpoint_state.go index c9fa9974a..a3c6db5a8 100644 --- a/pkg/tcpip/transport/icmp/endpoint_state.go +++ b/pkg/tcpip/transport/icmp/endpoint_state.go @@ -82,7 +82,7 @@ func (e *endpoint) Resume(s *stack.Stack) { panic(err) } - e.ID.LocalAddress = e.route.LocalAddress + e.ID.LocalAddress = e.route.LocalAddress() } else if len(e.ID.LocalAddress) != 0 { // stateBound if e.stack.CheckLocalAddress(e.RegisterNICID, e.NetProto, e.ID.LocalAddress) == 0 { panic(&tcpip.ErrBadLocalAddress{}) diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go index 2709be90c..4b2f08379 100644 --- a/pkg/tcpip/transport/raw/endpoint.go +++ b/pkg/tcpip/transport/raw/endpoint.go @@ -614,7 +614,7 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) { // If connected, only accept packets from the remote address we // connected to. - if e.connected && e.route.RemoteAddress != remoteAddr { + if e.connected && e.route.RemoteAddress() != remoteAddr { e.rcvMu.Unlock() e.mu.RUnlock() return diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go index 3404af6bb..b32fe2fb1 100644 --- a/pkg/tcpip/transport/tcp/connect.go +++ b/pkg/tcpip/transport/tcp/connect.go @@ -811,7 +811,7 @@ func sendTCP(r *stack.Route, tf tcpFields, data buffer.VectorisedView, gso *stac tf.rcvWnd = math.MaxUint16 } - if r.Loop&stack.PacketLoop == 0 && gso != nil && gso.Type == stack.GSOSW && int(gso.MSS) < data.Size() { + if r.Loop()&stack.PacketLoop == 0 && gso != nil && gso.Type == stack.GSOSW && int(gso.MSS) < data.Size() { return sendTCPBatch(r, tf, data, gso, owner) } diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index 43d344350..0a5e9cbb4 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -2211,8 +2211,8 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) tcp defer r.Release() netProtos := []tcpip.NetworkProtocolNumber{netProto} - e.ID.LocalAddress = r.LocalAddress - e.ID.RemoteAddress = r.RemoteAddress + e.ID.LocalAddress = r.LocalAddress() + e.ID.RemoteAddress = r.RemoteAddress() e.ID.RemotePort = addr.Port if e.ID.LocalPort != 0 { @@ -3102,7 +3102,7 @@ func (e *endpoint) completeState() stack.TCPEndpointState { func (e *endpoint) initHardwareGSO() { gso := &stack.GSO{} - switch e.route.NetProto { + switch e.route.NetProto() { case header.IPv4ProtocolNumber: gso.Type = stack.GSOTCPv4 gso.L3HdrLen = header.IPv4MinimumSize diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index c0f566459..0f59181bb 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -534,11 +534,11 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcp if so.GetRecvError() { so.QueueLocalErr( &tcpip.ErrMessageTooLong{}, - route.NetProto, + route.NetProto(), header.UDPMaximumPacketSize, tcpip.FullAddress{ NIC: route.NICID(), - Addr: route.RemoteAddress, + Addr: route.RemoteAddress(), Port: dstPort, }, v, @@ -550,7 +550,7 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcp ttl := e.ttl useDefaultTTL := ttl == 0 - if header.IsV4MulticastAddress(route.RemoteAddress) || header.IsV6MulticastAddress(route.RemoteAddress) { + if header.IsV4MulticastAddress(route.RemoteAddress()) || header.IsV6MulticastAddress(route.RemoteAddress()) { ttl = e.multicastTTL // Multicast allows a 0 TTL. useDefaultTTL = false @@ -861,7 +861,7 @@ func sendUDP(r *stack.Route, data buffer.VectorisedView, localPort, remotePort u // transmitter skipped the checksum generation (RFC768). // On IPv6, UDP checksum is not optional (RFC2460 Section 8.1). if r.RequiresTXTransportChecksum() && - (!noChecksum || r.NetProto == header.IPv6ProtocolNumber) { + (!noChecksum || r.NetProto() == header.IPv6ProtocolNumber) { xsum := r.PseudoHeaderChecksum(ProtocolNumber, length) for _, v := range data.Views() { xsum = header.Checksum(v, xsum) @@ -992,11 +992,11 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) tcpip.Error { LocalAddress: e.ID.LocalAddress, LocalPort: localPort, RemotePort: addr.Port, - RemoteAddress: r.RemoteAddress, + RemoteAddress: r.RemoteAddress(), } if e.EndpointState() == StateInitial { - id.LocalAddress = r.LocalAddress + id.LocalAddress = r.LocalAddress() } // Even if we're connected, this endpoint can still be used to send @@ -1204,7 +1204,7 @@ func (e *endpoint) GetLocalAddress() (tcpip.FullAddress, tcpip.Error) { addr := e.ID.LocalAddress if e.EndpointState() == StateConnected { - addr = e.route.LocalAddress + addr = e.route.LocalAddress() } return tcpip.FullAddress{ -- cgit v1.2.3