summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGhanan Gowripalan <ghanan@google.com>2021-10-27 13:39:24 -0700
committergVisor bot <gvisor-bot@google.com>2021-10-27 13:41:53 -0700
commit3015c0ac67ef7703899e753121efe326dc0cbecd (patch)
tree58dc69fb3f6fde39bb4fd8d0c18a79906c6bcfa4
parent22a6a37079c69129d10abfbdd6fdfdf7a9d4a68d (diff)
NAT ICMPv4 errors
...so a NAT-ed connection's socket can handle ICMP errors. Updates #5916. PiperOrigin-RevId: 405970089
-rw-r--r--pkg/tcpip/stack/conntrack.go156
-rw-r--r--pkg/tcpip/tests/integration/iptables_test.go293
-rw-r--r--pkg/tcpip/tests/utils/utils.go114
3 files changed, 495 insertions, 68 deletions
diff --git a/pkg/tcpip/stack/conntrack.go b/pkg/tcpip/stack/conntrack.go
index a3f403855..4a28be585 100644
--- a/pkg/tcpip/stack/conntrack.go
+++ b/pkg/tcpip/stack/conntrack.go
@@ -209,19 +209,120 @@ type bucket struct {
tuples tupleList
}
-func getTransportHeader(pkt *PacketBuffer) (header.ChecksummableTransport, bool) {
+func getHeaders(pkt *PacketBuffer) (netHdr header.Network, transHdr header.ChecksummableTransport, isICMPError bool, ok bool) {
switch pkt.TransportProtocolNumber {
case header.TCPProtocolNumber:
if tcpHeader := header.TCP(pkt.TransportHeader().View()); len(tcpHeader) >= header.TCPMinimumSize {
- return tcpHeader, true
+ return pkt.Network(), tcpHeader, false, true
}
case header.UDPProtocolNumber:
if udpHeader := header.UDP(pkt.TransportHeader().View()); len(udpHeader) >= header.UDPMinimumSize {
- return udpHeader, true
+ return pkt.Network(), udpHeader, false, true
+ }
+ case header.ICMPv4ProtocolNumber:
+ h, ok := pkt.Data().PullUp(header.IPv4MinimumSize)
+ if !ok {
+ panic(fmt.Sprintf("should have a valid IPv4 packet; only have %d bytes, want at least %d bytes", pkt.Data().Size(), header.IPv4MinimumSize))
+ }
+
+ ipv4 := header.IPv4(h)
+ if ipv4.HeaderLength() > header.IPv4MinimumSize {
+ // TODO(https://gvisor.dev/issue/6765): Handle IPv4 options.
+ panic("should have dropped packets with IPv4 options")
+ }
+
+ switch pkt.tuple.id().transProto {
+ case header.TCPProtocolNumber:
+ // TODO(https://gvisor.dev/issue/6765): Handle IPv4 options.
+ netAndTransHeader, ok := pkt.Data().PullUp(header.IPv4MinimumSize + header.TCPMinimumSize)
+ if !ok {
+ return nil, nil, false, false
+ }
+ netHeader := header.IPv4(netAndTransHeader)
+ return netHeader, header.TCP(netHeader.Payload()), true, true
+ case header.UDPProtocolNumber:
+ // TODO(https://gvisor.dev/issue/6765): Handle IPv4 options.
+ netAndTransHeader, ok := pkt.Data().PullUp(header.IPv4MinimumSize + header.UDPMinimumSize)
+ if !ok {
+ return nil, nil, false, false
+ }
+ netHeader := header.IPv4(netAndTransHeader)
+ return netHeader, header.UDP(netHeader.Payload()), true, true
}
}
- return nil, false
+ return nil, nil, false, false
+}
+
+func getTupleIDForRegularPacket(netHdr header.Network, netProto tcpip.NetworkProtocolNumber, transHdr header.Transport, transProto tcpip.TransportProtocolNumber) tupleID {
+ return tupleID{
+ srcAddr: netHdr.SourceAddress(),
+ srcPort: transHdr.SourcePort(),
+ dstAddr: netHdr.DestinationAddress(),
+ dstPort: transHdr.DestinationPort(),
+ transProto: transProto,
+ netProto: netProto,
+ }
+}
+
+func getTupleIDForPacketInICMPError(netHdr header.Network, netProto tcpip.NetworkProtocolNumber, transHdr header.Transport, transProto tcpip.TransportProtocolNumber) tupleID {
+ return tupleID{
+ srcAddr: netHdr.DestinationAddress(),
+ srcPort: transHdr.DestinationPort(),
+ dstAddr: netHdr.SourceAddress(),
+ dstPort: transHdr.SourcePort(),
+ transProto: transProto,
+ netProto: netProto,
+ }
+}
+
+func getTupleID(pkt *PacketBuffer) (tid tupleID, isICMPError bool, ok bool) {
+ switch pkt.TransportProtocolNumber {
+ case header.TCPProtocolNumber:
+ if transHeader := header.TCP(pkt.TransportHeader().View()); len(transHeader) >= header.TCPMinimumSize {
+ return getTupleIDForRegularPacket(pkt.Network(), pkt.NetworkProtocolNumber, transHeader, pkt.TransportProtocolNumber), false, true
+ }
+ case header.UDPProtocolNumber:
+ if transHeader := header.UDP(pkt.TransportHeader().View()); len(transHeader) >= header.UDPMinimumSize {
+ return getTupleIDForRegularPacket(pkt.Network(), pkt.NetworkProtocolNumber, transHeader, pkt.TransportProtocolNumber), false, true
+ }
+ case header.ICMPv4ProtocolNumber:
+ icmp := header.ICMPv4(pkt.TransportHeader().View())
+ if len(icmp) < header.ICMPv4MinimumSize {
+ return tupleID{}, false, false
+ }
+
+ switch icmp.Type() {
+ case header.ICMPv4DstUnreachable, header.ICMPv4TimeExceeded, header.ICMPv4ParamProblem:
+ default:
+ return tupleID{}, false, false
+ }
+
+ h, ok := pkt.Data().PullUp(header.IPv4MinimumSize)
+ if !ok {
+ return tupleID{}, false, false
+ }
+
+ ipv4 := header.IPv4(h)
+ if ipv4.HeaderLength() > header.IPv4MinimumSize {
+ // TODO(https://gvisor.dev/issue/6765): Handle IPv4 options.
+ return tupleID{}, false, false
+ }
+ switch ipv4.TransportProtocol() {
+ case header.TCPProtocolNumber:
+ if netAndTransHeader, ok := pkt.Data().PullUp(header.IPv4MinimumSize + header.TCPMinimumSize); ok {
+ netHdr := header.IPv4(netAndTransHeader)
+ return getTupleIDForPacketInICMPError(netHdr, header.IPv4ProtocolNumber, header.TCP(netHdr.Payload()), header.TCPProtocolNumber), true, true
+ }
+ case header.UDPProtocolNumber:
+ if netAndTransHeader, ok := pkt.Data().PullUp(header.IPv4MinimumSize + header.UDPMinimumSize); ok {
+ netHdr := header.IPv4(netAndTransHeader)
+ return getTupleIDForPacketInICMPError(netHdr, header.IPv4ProtocolNumber, header.UDP(netHdr.Payload()), header.UDPProtocolNumber), true, true
+ }
+ }
+ }
+
+ return tupleID{}, false, false
}
func (ct *ConnTrack) init() {
@@ -231,21 +332,11 @@ func (ct *ConnTrack) init() {
}
func (ct *ConnTrack) getConnOrMaybeInsertNoop(pkt *PacketBuffer) *tuple {
- netHeader := pkt.Network()
- transportHeader, ok := getTransportHeader(pkt)
+ tid, isICMPError, ok := getTupleID(pkt)
if !ok {
return nil
}
- tid := tupleID{
- srcAddr: netHeader.SourceAddress(),
- srcPort: transportHeader.SourcePort(),
- dstAddr: netHeader.DestinationAddress(),
- dstPort: transportHeader.DestinationPort(),
- transProto: pkt.TransportProtocolNumber,
- netProto: pkt.NetworkProtocolNumber,
- }
-
bktID := ct.bucket(tid)
ct.mu.RLock()
@@ -257,6 +348,11 @@ func (ct *ConnTrack) getConnOrMaybeInsertNoop(pkt *PacketBuffer) *tuple {
return t
}
+ if isICMPError {
+ // Do not create a noop entry in response to an ICMP error.
+ return nil
+ }
+
bkt.mu.Lock()
defer bkt.mu.Unlock()
@@ -407,7 +503,7 @@ func (cn *conn) performNATIfNoop(port uint16, address tcpip.Address, dnat bool)
//
// Returns true if the packet can skip the NAT table.
func (cn *conn) handlePacket(pkt *PacketBuffer, hook Hook, rt *Route) bool {
- transportHeader, ok := getTransportHeader(pkt)
+ netHdr, transHdr, isICMPError, ok := getHeaders(pkt)
if !ok {
return false
}
@@ -498,9 +594,9 @@ func (cn *conn) handlePacket(pkt *PacketBuffer, hook Hook, rt *Route) bool {
}
rewritePacket(
- pkt.Network(),
- transportHeader,
- !dnat,
+ netHdr,
+ transHdr,
+ !dnat != isICMPError,
fullChecksum,
updatePseudoHeader,
newPort,
@@ -508,6 +604,28 @@ func (cn *conn) handlePacket(pkt *PacketBuffer, hook Hook, rt *Route) bool {
)
*natDone = true
+
+ if !isICMPError {
+ return true
+ }
+
+ // We performed NAT on (erroneous) packet that triggered an ICMP response, but
+ // not the ICMP packet itself.
+ switch pkt.TransportProtocolNumber {
+ case header.ICMPv4ProtocolNumber:
+ icmp := header.ICMPv4(pkt.TransportHeader().View())
+ // TODO(https://gvisor.dev/issue/6788): Incrementally update ICMP checksum.
+ icmp.SetChecksum(0)
+ icmp.SetChecksum(header.ICMPv4Checksum(icmp, pkt.Data().AsRange().Checksum()))
+
+ network := header.IPv4(pkt.NetworkHeader().View())
+ if dnat {
+ network.SetDestinationAddressWithChecksumUpdate(tid.srcAddr)
+ } else {
+ network.SetSourceAddressWithChecksumUpdate(tid.dstAddr)
+ }
+ }
+
return true
}
diff --git a/pkg/tcpip/tests/integration/iptables_test.go b/pkg/tcpip/tests/integration/iptables_test.go
index 957a779bf..9e00a6350 100644
--- a/pkg/tcpip/tests/integration/iptables_test.go
+++ b/pkg/tcpip/tests/integration/iptables_test.go
@@ -1779,3 +1779,296 @@ func TestNAT(t *testing.T) {
})
}
}
+
+func TestNATICMPError(t *testing.T) {
+ const srcPort = 1234
+ const dstPort = 5432
+
+ type icmpTypeTest struct {
+ name string
+ val uint8
+ expectResponse bool
+ }
+
+ type transportTypeTest struct {
+ name string
+ proto tcpip.TransportProtocolNumber
+ buf buffer.View
+ checkNATed func(*testing.T, buffer.View)
+ }
+
+ ipHdr := func(v buffer.View, totalLen int, transProto tcpip.TransportProtocolNumber, srcAddr, dstAddr tcpip.Address) {
+ ip := header.IPv4(v)
+ ip.Encode(&header.IPv4Fields{
+ TotalLength: uint16(totalLen),
+ Protocol: uint8(transProto),
+ TTL: 64,
+ SrcAddr: srcAddr,
+ DstAddr: dstAddr,
+ })
+ ip.SetChecksum(^ip.CalculateChecksum())
+ }
+
+ tests := []struct {
+ name string
+ netProto tcpip.NetworkProtocolNumber
+ host1Addr tcpip.Address
+ icmpError func(*testing.T, buffer.View, uint8) buffer.View
+ decrementTTL func(buffer.View)
+ checkNATedError func(*testing.T, buffer.View, buffer.View, uint8)
+
+ transportTypes []transportTypeTest
+ icmpTypes []icmpTypeTest
+ }{
+ {
+ name: "IPv4",
+ netProto: ipv4.ProtocolNumber,
+ host1Addr: utils.Host1IPv4Addr.AddressWithPrefix.Address,
+ icmpError: func(t *testing.T, original buffer.View, icmpType uint8) buffer.View {
+ totalLen := header.IPv4MinimumSize + header.ICMPv4MinimumSize + len(original)
+ hdr := buffer.NewPrependable(totalLen)
+ if n := copy(hdr.Prepend(len(original)), original); n != len(original) {
+ t.Fatalf("got copy(...) = %d, want = %d", n, len(original))
+ }
+ icmp := header.ICMPv4(hdr.Prepend(header.ICMPv4MinimumSize))
+ icmp.SetType(header.ICMPv4Type(icmpType))
+ icmp.SetChecksum(0)
+ icmp.SetChecksum(header.ICMPv4Checksum(icmp, 0))
+ ipHdr(hdr.Prepend(header.IPv4MinimumSize),
+ totalLen,
+ header.ICMPv4ProtocolNumber,
+ utils.Host1IPv4Addr.AddressWithPrefix.Address,
+ utils.RouterNIC1IPv4Addr.AddressWithPrefix.Address,
+ )
+ return hdr.View()
+ },
+ decrementTTL: func(v buffer.View) {
+ ip := header.IPv4(v)
+ ip.SetTTL(ip.TTL() - 1)
+ ip.SetChecksum(0)
+ ip.SetChecksum(^ip.CalculateChecksum())
+ },
+ checkNATedError: func(t *testing.T, v buffer.View, original buffer.View, icmpType uint8) {
+ checker.IPv4(t, v,
+ checker.SrcAddr(utils.RouterNIC2IPv4Addr.AddressWithPrefix.Address),
+ checker.DstAddr(utils.Host2IPv4Addr.AddressWithPrefix.Address),
+ checker.ICMPv4(
+ checker.ICMPv4Type(header.ICMPv4Type(icmpType)),
+ checker.ICMPv4Checksum(),
+ checker.ICMPv4Payload(original),
+ ),
+ )
+ },
+ transportTypes: []transportTypeTest{
+ {
+ name: "UDP",
+ proto: header.UDPProtocolNumber,
+ buf: func() buffer.View {
+ totalLen := header.IPv4MinimumSize + header.UDPMinimumSize
+ hdr := buffer.NewPrependable(totalLen)
+ udp := header.UDP(hdr.Prepend(header.UDPMinimumSize))
+ udp.SetSourcePort(srcPort)
+ udp.SetDestinationPort(dstPort)
+ udp.SetChecksum(0)
+ udp.SetChecksum(^udp.CalculateChecksum(header.PseudoHeaderChecksum(
+ header.UDPProtocolNumber,
+ utils.Host2IPv4Addr.AddressWithPrefix.Address,
+ utils.RouterNIC2IPv4Addr.AddressWithPrefix.Address,
+ uint16(len(udp)),
+ )))
+ ipHdr(hdr.Prepend(header.IPv4MinimumSize),
+ totalLen,
+ header.UDPProtocolNumber,
+ utils.Host2IPv4Addr.AddressWithPrefix.Address,
+ utils.RouterNIC2IPv4Addr.AddressWithPrefix.Address,
+ )
+ return hdr.View()
+ }(),
+ checkNATed: func(t *testing.T, v buffer.View) {
+ checker.IPv4(t, v,
+ checker.SrcAddr(utils.RouterNIC1IPv4Addr.AddressWithPrefix.Address),
+ checker.DstAddr(utils.Host1IPv4Addr.AddressWithPrefix.Address),
+ checker.UDP(
+ checker.SrcPort(srcPort),
+ checker.DstPort(dstPort),
+ ),
+ )
+ },
+ },
+ {
+ name: "TCP",
+ proto: header.TCPProtocolNumber,
+ buf: func() buffer.View {
+ totalLen := header.IPv4MinimumSize + header.TCPMinimumSize
+ hdr := buffer.NewPrependable(totalLen)
+ tcp := header.TCP(hdr.Prepend(header.TCPMinimumSize))
+ tcp.SetSourcePort(srcPort)
+ tcp.SetDestinationPort(dstPort)
+ tcp.SetDataOffset(header.TCPMinimumSize)
+ tcp.SetChecksum(0)
+ tcp.SetChecksum(^tcp.CalculateChecksum(header.PseudoHeaderChecksum(
+ header.TCPProtocolNumber,
+ utils.Host2IPv4Addr.AddressWithPrefix.Address,
+ utils.RouterNIC2IPv4Addr.AddressWithPrefix.Address,
+ uint16(len(tcp)),
+ )))
+ ipHdr(hdr.Prepend(header.IPv4MinimumSize),
+ totalLen,
+ header.TCPProtocolNumber,
+ utils.Host2IPv4Addr.AddressWithPrefix.Address,
+ utils.RouterNIC2IPv4Addr.AddressWithPrefix.Address,
+ )
+ return hdr.View()
+ }(),
+ checkNATed: func(t *testing.T, v buffer.View) {
+ checker.IPv4(t, v,
+ checker.SrcAddr(utils.RouterNIC1IPv4Addr.AddressWithPrefix.Address),
+ checker.DstAddr(utils.Host1IPv4Addr.AddressWithPrefix.Address),
+ checker.TCP(
+ checker.SrcPort(srcPort),
+ checker.DstPort(dstPort),
+ ),
+ )
+ },
+ },
+ },
+ icmpTypes: []icmpTypeTest{
+ {
+ name: "Destination Unreachable",
+ val: uint8(header.ICMPv4DstUnreachable),
+ expectResponse: true,
+ },
+ {
+ name: "Time Exceeded",
+ val: uint8(header.ICMPv4TimeExceeded),
+ expectResponse: true,
+ },
+ {
+ name: "Parameter Problem",
+ val: uint8(header.ICMPv4ParamProblem),
+ expectResponse: true,
+ },
+ {
+ name: "Echo Request",
+ val: uint8(header.ICMPv4Echo),
+ expectResponse: false,
+ },
+ {
+ name: "Echo Reply",
+ val: uint8(header.ICMPv4EchoReply),
+ expectResponse: false,
+ },
+ },
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ for _, transportType := range test.transportTypes {
+ t.Run(transportType.name, func(t *testing.T) {
+ for _, icmpType := range test.icmpTypes {
+ t.Run(icmpType.name, func(t *testing.T) {
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
+ TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol, tcp.NewProtocol},
+ })
+
+ ep1 := channel.New(1, header.IPv6MinimumMTU, "")
+ ep2 := channel.New(1, header.IPv6MinimumMTU, "")
+ utils.SetupRouterStack(t, s, ep1, ep2)
+
+ ipv6 := test.netProto == ipv6.ProtocolNumber
+ ipt := s.IPTables()
+
+ table := stack.Table{
+ Rules: []stack.Rule{
+ // Prerouting
+ {
+ Filter: stack.IPHeaderFilter{
+ Protocol: transportType.proto,
+ CheckProtocol: true,
+ InputInterface: utils.RouterNIC2Name,
+ },
+ Target: &stack.DNATTarget{NetworkProtocol: test.netProto, Addr: test.host1Addr, Port: dstPort},
+ },
+ {
+ Target: &stack.AcceptTarget{},
+ },
+
+ // Input
+ {
+ Target: &stack.AcceptTarget{},
+ },
+
+ // Forward
+ {
+ Target: &stack.AcceptTarget{},
+ },
+
+ // Output
+ {
+ Target: &stack.AcceptTarget{},
+ },
+
+ // Postrouting
+ {
+ Filter: stack.IPHeaderFilter{
+ Protocol: transportType.proto,
+ CheckProtocol: true,
+ OutputInterface: utils.RouterNIC1Name,
+ },
+ Target: &stack.MasqueradeTarget{NetworkProtocol: test.netProto},
+ },
+ {
+ Target: &stack.AcceptTarget{},
+ },
+ },
+ BuiltinChains: [stack.NumHooks]int{
+ stack.Prerouting: 0,
+ stack.Input: 2,
+ stack.Forward: 3,
+ stack.Output: 4,
+ stack.Postrouting: 5,
+ },
+ }
+
+ if err := ipt.ReplaceTable(stack.NATID, table, ipv6); err != nil {
+ t.Fatalf("ipt.ReplaceTable(%d, _, %t): %s", stack.NATID, ipv6, err)
+ }
+
+ ep2.InjectInbound(test.netProto, stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: append(buffer.View(nil), transportType.buf...).ToVectorisedView(),
+ }))
+
+ {
+ pkt, ok := ep1.Read()
+ if !ok {
+ t.Fatal("expected to read a packet on ep1")
+ }
+ pktView := stack.PayloadSince(pkt.Pkt.NetworkHeader())
+ transportType.checkNATed(t, pktView)
+ if t.Failed() {
+ t.FailNow()
+ }
+
+ ep1.InjectInbound(test.netProto, stack.NewPacketBuffer(stack.PacketBufferOptions{
+ Data: test.icmpError(t, pktView, icmpType.val).ToVectorisedView(),
+ }))
+ }
+
+ pkt, ok := ep2.Read()
+ if ok != icmpType.expectResponse {
+ t.Fatalf("got ep2.Read() = (%#v, %t), want = (_, %t)", pkt, ok, icmpType.expectResponse)
+ }
+ if !icmpType.expectResponse {
+ return
+ }
+ test.decrementTTL(transportType.buf)
+ test.checkNATedError(t, stack.PayloadSince(pkt.Pkt.NetworkHeader()), transportType.buf, icmpType.val)
+ })
+ }
+ })
+ }
+ })
+ }
+}
diff --git a/pkg/tcpip/tests/utils/utils.go b/pkg/tcpip/tests/utils/utils.go
index c69410859..654309584 100644
--- a/pkg/tcpip/tests/utils/utils.go
+++ b/pkg/tcpip/tests/utils/utils.go
@@ -213,12 +213,77 @@ func (e *EndpointWithDestinationCheck) DeliverNetworkPacket(src, dst tcpip.LinkA
}
}
+// SetupRouterStack creates the NICs, sets forwarding, adds addresses and sets
+// the route table for a stack that should operate as a router.
+func SetupRouterStack(t *testing.T, s *stack.Stack, ep1, ep2 stack.LinkEndpoint) {
+
+ if err := s.SetForwardingDefaultAndAllNICs(ipv4.ProtocolNumber, true); err != nil {
+ t.Fatalf("s.SetForwardingDefaultAndAllNICs(%d): %s", ipv4.ProtocolNumber, err)
+ }
+ if err := s.SetForwardingDefaultAndAllNICs(ipv6.ProtocolNumber, true); err != nil {
+ t.Fatalf("s.SetForwardingDefaultAndAllNICs(%d): %s", ipv6.ProtocolNumber, err)
+ }
+
+ for _, setup := range []struct {
+ nicID tcpip.NICID
+ nicName string
+ ep stack.LinkEndpoint
+
+ addresses [2]tcpip.ProtocolAddress
+ }{
+ {
+ nicID: RouterNICID1,
+ nicName: RouterNIC1Name,
+ ep: ep1,
+ addresses: [2]tcpip.ProtocolAddress{RouterNIC1IPv4Addr, RouterNIC1IPv6Addr},
+ },
+ {
+ nicID: RouterNICID2,
+ nicName: RouterNIC2Name,
+ ep: ep2,
+ addresses: [2]tcpip.ProtocolAddress{RouterNIC2IPv4Addr, RouterNIC2IPv6Addr},
+ },
+ } {
+ opts := stack.NICOptions{Name: setup.nicName}
+ if err := s.CreateNICWithOptions(setup.nicID, setup.ep, opts); err != nil {
+ t.Fatalf("s.CreateNICWithOptions(%d, _, %#v): %s", setup.nicID, opts, err)
+ }
+
+ for _, addr := range setup.addresses {
+ if err := s.AddProtocolAddress(setup.nicID, addr, stack.AddressProperties{}); err != nil {
+ t.Fatalf("s.AddProtocolAddress(%d, %#v, {}): %s", setup.nicID, addr, err)
+ }
+ }
+ }
+
+ s.SetRouteTable([]tcpip.Route{
+ {
+ Destination: RouterNIC1IPv4Addr.AddressWithPrefix.Subnet(),
+ NIC: RouterNICID1,
+ },
+ {
+ Destination: RouterNIC1IPv6Addr.AddressWithPrefix.Subnet(),
+ NIC: RouterNICID1,
+ },
+ {
+ Destination: RouterNIC2IPv4Addr.AddressWithPrefix.Subnet(),
+ NIC: RouterNICID2,
+ },
+ {
+ Destination: RouterNIC2IPv6Addr.AddressWithPrefix.Subnet(),
+ NIC: RouterNICID2,
+ },
+ })
+}
+
// SetupRoutedStacks creates the NICs, sets forwarding, adds addresses and sets
// the route tables for the passed stacks.
func SetupRoutedStacks(t *testing.T, host1Stack, routerStack, host2Stack *stack.Stack) {
host1NIC, routerNIC1 := pipe.New(LinkAddr1, LinkAddr2)
routerNIC2, host2NIC := pipe.New(LinkAddr3, LinkAddr4)
+ SetupRouterStack(t, routerStack, NewEthernetEndpoint(routerNIC1), NewEthernetEndpoint(routerNIC2))
+
{
opts := stack.NICOptions{Name: Host1NICName}
if err := host1Stack.CreateNICWithOptions(Host1NICID, NewEthernetEndpoint(host1NIC), opts); err != nil {
@@ -226,52 +291,21 @@ func SetupRoutedStacks(t *testing.T, host1Stack, routerStack, host2Stack *stack.
}
}
{
- opts := stack.NICOptions{Name: RouterNIC1Name}
- if err := routerStack.CreateNICWithOptions(RouterNICID1, NewEthernetEndpoint(routerNIC1), opts); err != nil {
- t.Fatalf("routerStack.CreateNICWithOptions(%d, _, %#v): %s", RouterNICID1, opts, err)
- }
- }
- {
- opts := stack.NICOptions{Name: RouterNIC2Name}
- if err := routerStack.CreateNICWithOptions(RouterNICID2, NewEthernetEndpoint(routerNIC2), opts); err != nil {
- t.Fatalf("routerStack.CreateNICWithOptions(%d, _, %#v): %s", RouterNICID2, opts, err)
- }
- }
- {
opts := stack.NICOptions{Name: Host2NICName}
if err := host2Stack.CreateNICWithOptions(Host2NICID, NewEthernetEndpoint(host2NIC), opts); err != nil {
t.Fatalf("host2Stack.CreateNICWithOptions(%d, _, %#v): %s", Host2NICID, opts, err)
}
}
- if err := routerStack.SetForwardingDefaultAndAllNICs(ipv4.ProtocolNumber, true); err != nil {
- t.Fatalf("routerStack.SetForwardingDefaultAndAllNICs(%d): %s", ipv4.ProtocolNumber, err)
- }
- if err := routerStack.SetForwardingDefaultAndAllNICs(ipv6.ProtocolNumber, true); err != nil {
- t.Fatalf("routerStack.SetForwardingDefaultAndAllNICs(%d): %s", ipv6.ProtocolNumber, err)
- }
-
if err := host1Stack.AddProtocolAddress(Host1NICID, Host1IPv4Addr, stack.AddressProperties{}); err != nil {
t.Fatalf("host1Stack.AddProtocolAddress(%d, %+v, {}): %s", Host1NICID, Host1IPv4Addr, err)
}
- if err := routerStack.AddProtocolAddress(RouterNICID1, RouterNIC1IPv4Addr, stack.AddressProperties{}); err != nil {
- t.Fatalf("routerStack.AddProtocolAddress(%d, %+v, {}): %s", RouterNICID1, RouterNIC1IPv4Addr, err)
- }
- if err := routerStack.AddProtocolAddress(RouterNICID2, RouterNIC2IPv4Addr, stack.AddressProperties{}); err != nil {
- t.Fatalf("routerStack.AddProtocolAddress(%d, %+v, {}): %s", RouterNICID2, RouterNIC2IPv4Addr, err)
- }
if err := host2Stack.AddProtocolAddress(Host2NICID, Host2IPv4Addr, stack.AddressProperties{}); err != nil {
t.Fatalf("host2Stack.AddProtocolAddress(%d, %+v, {}): %s", Host2NICID, Host2IPv4Addr, err)
}
if err := host1Stack.AddProtocolAddress(Host1NICID, Host1IPv6Addr, stack.AddressProperties{}); err != nil {
t.Fatalf("host1Stack.AddProtocolAddress(%d, %+v, {}): %s", Host1NICID, Host1IPv6Addr, err)
}
- if err := routerStack.AddProtocolAddress(RouterNICID1, RouterNIC1IPv6Addr, stack.AddressProperties{}); err != nil {
- t.Fatalf("routerStack.AddProtocolAddress(%d, %+v, {}): %s", RouterNICID1, RouterNIC1IPv6Addr, err)
- }
- if err := routerStack.AddProtocolAddress(RouterNICID2, RouterNIC2IPv6Addr, stack.AddressProperties{}); err != nil {
- t.Fatalf("routerStack.AddProtocolAddress(%d, %+v, {}): %s", RouterNICID2, RouterNIC2IPv6Addr, err)
- }
if err := host2Stack.AddProtocolAddress(Host2NICID, Host2IPv6Addr, stack.AddressProperties{}); err != nil {
t.Fatalf("host2Stack.AddProtocolAddress(%d, %+v, {}): %s", Host2NICID, Host2IPv6Addr, err)
}
@@ -296,24 +330,6 @@ func SetupRoutedStacks(t *testing.T, host1Stack, routerStack, host2Stack *stack.
NIC: Host1NICID,
},
})
- routerStack.SetRouteTable([]tcpip.Route{
- {
- Destination: RouterNIC1IPv4Addr.AddressWithPrefix.Subnet(),
- NIC: RouterNICID1,
- },
- {
- Destination: RouterNIC1IPv6Addr.AddressWithPrefix.Subnet(),
- NIC: RouterNICID1,
- },
- {
- Destination: RouterNIC2IPv4Addr.AddressWithPrefix.Subnet(),
- NIC: RouterNICID2,
- },
- {
- Destination: RouterNIC2IPv6Addr.AddressWithPrefix.Subnet(),
- NIC: RouterNICID2,
- },
- })
host2Stack.SetRouteTable([]tcpip.Route{
{
Destination: Host2IPv4Addr.AddressWithPrefix.Subnet(),