From 3ff1aef544b532c207cf55bcee64fe6717bfd3c4 Mon Sep 17 00:00:00 2001 From: Peter Johnston Date: Thu, 3 Dec 2020 08:52:05 -0800 Subject: Make `stack.Route` thread safe Currently we rely on the user to take the lock on the endpoint that owns the route, in order to modify it safely. We can instead move `Route.RemoteLinkAddress` under `Route`'s mutex, and allow non-locking and thread-safe access to other fields of `Route`. PiperOrigin-RevId: 345461586 --- pkg/tcpip/network/arp/arp_test.go | 8 ++++---- pkg/tcpip/network/ipv4/ipv4_test.go | 12 ++++++------ pkg/tcpip/network/ipv6/icmp_test.go | 24 ++++++++++++------------ pkg/tcpip/network/ipv6/ndp_test.go | 8 ++++---- 4 files changed, 26 insertions(+), 26 deletions(-) (limited to 'pkg/tcpip/network') diff --git a/pkg/tcpip/network/arp/arp_test.go b/pkg/tcpip/network/arp/arp_test.go index a738e9e1c..0fb373612 100644 --- a/pkg/tcpip/network/arp/arp_test.go +++ b/pkg/tcpip/network/arp/arp_test.go @@ -442,9 +442,9 @@ func (*testInterface) Promiscuous() bool { func (t *testInterface) WritePacketToRemote(remoteLinkAddr tcpip.LinkAddress, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error { r := stack.Route{ - NetProto: protocol, - RemoteLinkAddress: remoteLinkAddr, + NetProto: protocol, } + r.ResolveWith(remoteLinkAddr) return t.LinkEndpoint.WritePacket(&r, gso, protocol, pkt) } @@ -557,8 +557,8 @@ func TestLinkAddressRequest(t *testing.T) { t.Fatal("expected to send a link address request") } - if pkt.Route.RemoteLinkAddress != test.expectedRemoteLinkAddr { - t.Errorf("got pkt.Route.RemoteLinkAddress = %s, want = %s", pkt.Route.RemoteLinkAddress, test.expectedRemoteLinkAddr) + if got := pkt.Route.RemoteLinkAddress(); got != test.expectedRemoteLinkAddr { + t.Errorf("got pkt.Route.RemoteLinkAddress() = %s, want = %s", got, test.expectedRemoteLinkAddr) } rep := header.ARP(stack.PayloadSince(pkt.Pkt.NetworkHeader())) diff --git a/pkg/tcpip/network/ipv4/ipv4_test.go b/pkg/tcpip/network/ipv4/ipv4_test.go index 9b660eef0..2d633ca23 100644 --- a/pkg/tcpip/network/ipv4/ipv4_test.go +++ b/pkg/tcpip/network/ipv4/ipv4_test.go @@ -2770,8 +2770,8 @@ func TestPacketQueing(t *testing.T) { if p.Proto != header.IPv4ProtocolNumber { t.Errorf("got p.Proto = %d, want = %d", p.Proto, header.IPv4ProtocolNumber) } - if p.Route.RemoteLinkAddress != host2NICLinkAddr { - t.Errorf("got p.Route.RemoteLinkAddress = %s, want = %s", p.Route.RemoteLinkAddress, host2NICLinkAddr) + if got := p.Route.RemoteLinkAddress(); got != host2NICLinkAddr { + t.Errorf("got p.Route.RemoteLinkAddress() = %s, want = %s", got, host2NICLinkAddr) } checker.IPv4(t, stack.PayloadSince(p.Pkt.NetworkHeader()), checker.SrcAddr(host1IPv4Addr.AddressWithPrefix.Address), @@ -2813,8 +2813,8 @@ func TestPacketQueing(t *testing.T) { if p.Proto != header.IPv4ProtocolNumber { t.Errorf("got p.Proto = %d, want = %d", p.Proto, header.IPv4ProtocolNumber) } - if p.Route.RemoteLinkAddress != host2NICLinkAddr { - t.Errorf("got p.Route.RemoteLinkAddress = %s, want = %s", p.Route.RemoteLinkAddress, host2NICLinkAddr) + if got := p.Route.RemoteLinkAddress(); got != host2NICLinkAddr { + t.Errorf("got p.Route.RemoteLinkAddress() = %s, want = %s", got, host2NICLinkAddr) } checker.IPv4(t, stack.PayloadSince(p.Pkt.NetworkHeader()), checker.SrcAddr(host1IPv4Addr.AddressWithPrefix.Address), @@ -2862,8 +2862,8 @@ func TestPacketQueing(t *testing.T) { if p.Proto != arp.ProtocolNumber { t.Errorf("got p.Proto = %d, want = %d", p.Proto, arp.ProtocolNumber) } - if p.Route.RemoteLinkAddress != header.EthernetBroadcastAddress { - t.Errorf("got p.Route.RemoteLinkAddress = %s, want = %s", p.Route.RemoteLinkAddress, header.EthernetBroadcastAddress) + if got := p.Route.RemoteLinkAddress(); got != header.EthernetBroadcastAddress { + t.Errorf("got p.Route.RemoteLinkAddress() = %s, want = %s", got, header.EthernetBroadcastAddress) } rep := header.ARP(p.Pkt.NetworkHeader().View()) if got := rep.Op(); got != header.ARPRequest { diff --git a/pkg/tcpip/network/ipv6/icmp_test.go b/pkg/tcpip/network/ipv6/icmp_test.go index f19b7bec9..32adb5c83 100644 --- a/pkg/tcpip/network/ipv6/icmp_test.go +++ b/pkg/tcpip/network/ipv6/icmp_test.go @@ -150,9 +150,9 @@ func (*testInterface) Promiscuous() bool { func (t *testInterface) WritePacketToRemote(remoteLinkAddr tcpip.LinkAddress, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error { r := stack.Route{ - NetProto: protocol, - RemoteLinkAddress: remoteLinkAddr, + NetProto: protocol, } + r.ResolveWith(remoteLinkAddr) return t.LinkEndpoint.WritePacket(&r, gso, protocol, pkt) } @@ -600,8 +600,8 @@ func routeICMPv6Packet(t *testing.T, args routeArgs, fn func(*testing.T, header. return } - if len(args.remoteLinkAddr) != 0 && args.remoteLinkAddr != pi.Route.RemoteLinkAddress { - t.Errorf("got remote link address = %s, want = %s", pi.Route.RemoteLinkAddress, args.remoteLinkAddr) + if got := pi.Route.RemoteLinkAddress(); len(args.remoteLinkAddr) != 0 && got != args.remoteLinkAddr { + t.Errorf("got remote link address = %s, want = %s", got, args.remoteLinkAddr) } // Pull the full payload since network header. Needed for header.IPv6 to @@ -1381,8 +1381,8 @@ func TestLinkAddressRequest(t *testing.T) { if !ok { t.Fatal("expected to send a link address request") } - if pkt.Route.RemoteLinkAddress != test.expectedRemoteLinkAddr { - t.Errorf("got pkt.Route.RemoteLinkAddress = %s, want = %s", pkt.Route.RemoteLinkAddress, test.expectedRemoteLinkAddr) + if got := pkt.Route.RemoteLinkAddress(); got != test.expectedRemoteLinkAddr { + t.Errorf("got pkt.Route.RemoteLinkAddress() = %s, want = %s", got, test.expectedRemoteLinkAddr) } if pkt.Route.RemoteAddress != test.expectedRemoteAddr { t.Errorf("got pkt.Route.RemoteAddress = %s, want = %s", pkt.Route.RemoteAddress, test.expectedRemoteAddr) @@ -1463,8 +1463,8 @@ func TestPacketQueing(t *testing.T) { if p.Proto != ProtocolNumber { t.Errorf("got p.Proto = %d, want = %d", p.Proto, ProtocolNumber) } - if p.Route.RemoteLinkAddress != host2NICLinkAddr { - t.Errorf("got p.Route.RemoteLinkAddress = %s, want = %s", p.Route.RemoteLinkAddress, host2NICLinkAddr) + if got := p.Route.RemoteLinkAddress(); got != host2NICLinkAddr { + t.Errorf("got p.Route.RemoteLinkAddress() = %s, want = %s", got, host2NICLinkAddr) } checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()), checker.SrcAddr(host1IPv6Addr.AddressWithPrefix.Address), @@ -1505,8 +1505,8 @@ func TestPacketQueing(t *testing.T) { if p.Proto != ProtocolNumber { t.Errorf("got p.Proto = %d, want = %d", p.Proto, ProtocolNumber) } - if p.Route.RemoteLinkAddress != host2NICLinkAddr { - t.Errorf("got p.Route.RemoteLinkAddress = %s, want = %s", p.Route.RemoteLinkAddress, host2NICLinkAddr) + if got := p.Route.RemoteLinkAddress(); got != host2NICLinkAddr { + t.Errorf("got p.Route.RemoteLinkAddress() = %s, want = %s", got, host2NICLinkAddr) } checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()), checker.SrcAddr(host1IPv6Addr.AddressWithPrefix.Address), @@ -1556,8 +1556,8 @@ func TestPacketQueing(t *testing.T) { t.Errorf("got Proto = %d, want = %d", p.Proto, ProtocolNumber) } snmc := header.SolicitedNodeAddr(host2IPv6Addr.AddressWithPrefix.Address) - if want := header.EthernetAddressFromMulticastIPv6Address(snmc); p.Route.RemoteLinkAddress != want { - t.Errorf("got p.Route.RemoteLinkAddress = %s, want = %s", p.Route.RemoteLinkAddress, want) + if got, want := p.Route.RemoteLinkAddress(), header.EthernetAddressFromMulticastIPv6Address(snmc); got != want { + t.Errorf("got p.Route.RemoteLinkAddress() = %s, want = %s", got, want) } checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()), checker.SrcAddr(host1IPv6Addr.AddressWithPrefix.Address), diff --git a/pkg/tcpip/network/ipv6/ndp_test.go b/pkg/tcpip/network/ipv6/ndp_test.go index e2409306f..95c626bb8 100644 --- a/pkg/tcpip/network/ipv6/ndp_test.go +++ b/pkg/tcpip/network/ipv6/ndp_test.go @@ -650,8 +650,8 @@ func TestNeighorSolicitationResponse(t *testing.T) { if p.Route.RemoteAddress != respNSDst { t.Errorf("got p.Route.RemoteAddress = %s, want = %s", p.Route.RemoteAddress, respNSDst) } - if want := header.EthernetAddressFromMulticastIPv6Address(respNSDst); p.Route.RemoteLinkAddress != want { - t.Errorf("got p.Route.RemoteLinkAddress = %s, want = %s", p.Route.RemoteLinkAddress, want) + if got, want := p.Route.RemoteLinkAddress(), header.EthernetAddressFromMulticastIPv6Address(respNSDst); got != want { + t.Errorf("got p.Route.RemoteLinkAddress() = %s, want = %s", got, want) } checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()), @@ -706,8 +706,8 @@ func TestNeighorSolicitationResponse(t *testing.T) { if p.Route.RemoteAddress != test.naDst { t.Errorf("got p.Route.RemoteAddress = %s, want = %s", p.Route.RemoteAddress, test.naDst) } - if p.Route.RemoteLinkAddress != test.naDstLinkAddr { - t.Errorf("got p.Route.RemoteLinkAddress = %s, want = %s", p.Route.RemoteLinkAddress, test.naDstLinkAddr) + if got := p.Route.RemoteLinkAddress(); got != test.naDstLinkAddr { + t.Errorf("got p.Route.RemoteLinkAddress() = %s, want = %s", got, test.naDstLinkAddr) } checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()), -- cgit v1.2.3