diff options
Diffstat (limited to 'test/iptables/nat.go')
-rw-r--r-- | test/iptables/nat.go | 225 |
1 files changed, 119 insertions, 106 deletions
diff --git a/test/iptables/nat.go b/test/iptables/nat.go index b7fea2527..dd9a18339 100644 --- a/test/iptables/nat.go +++ b/test/iptables/nat.go @@ -15,11 +15,11 @@ package iptables import ( + "context" "errors" "fmt" "net" "syscall" - "time" ) const redirectPort = 42 @@ -46,7 +46,7 @@ func init() { } // NATPreRedirectUDPPort tests that packets are redirected to different port. -type NATPreRedirectUDPPort struct{} +type NATPreRedirectUDPPort struct{ containerCase } // Name implements TestCase.Name. func (NATPreRedirectUDPPort) Name() string { @@ -54,12 +54,12 @@ func (NATPreRedirectUDPPort) Name() string { } // ContainerAction implements TestCase.ContainerAction. -func (NATPreRedirectUDPPort) ContainerAction(ip net.IP, ipv6 bool) error { +func (NATPreRedirectUDPPort) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { if err := natTable(ipv6, "-A", "PREROUTING", "-p", "udp", "-j", "REDIRECT", "--to-ports", fmt.Sprintf("%d", redirectPort)); err != nil { return err } - if err := listenUDP(redirectPort, sendloopDuration); err != nil { + if err := listenUDP(ctx, redirectPort); err != nil { return fmt.Errorf("packets on port %d should be allowed, but encountered an error: %v", redirectPort, err) } @@ -67,12 +67,12 @@ func (NATPreRedirectUDPPort) ContainerAction(ip net.IP, ipv6 bool) error { } // LocalAction implements TestCase.LocalAction. -func (NATPreRedirectUDPPort) LocalAction(ip net.IP, ipv6 bool) error { - return spawnUDPLoop(ip, acceptPort, sendloopDuration) +func (NATPreRedirectUDPPort) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { + return sendUDPLoop(ctx, ip, acceptPort) } // NATPreRedirectTCPPort tests that connections are redirected on specified ports. -type NATPreRedirectTCPPort struct{} +type NATPreRedirectTCPPort struct{ baseCase } // Name implements TestCase.Name. func (NATPreRedirectTCPPort) Name() string { @@ -80,23 +80,23 @@ func (NATPreRedirectTCPPort) Name() string { } // ContainerAction implements TestCase.ContainerAction. -func (NATPreRedirectTCPPort) ContainerAction(ip net.IP, ipv6 bool) error { +func (NATPreRedirectTCPPort) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { if err := natTable(ipv6, "-A", "PREROUTING", "-p", "tcp", "-m", "tcp", "--dport", fmt.Sprintf("%d", dropPort), "-j", "REDIRECT", "--to-ports", fmt.Sprintf("%d", acceptPort)); err != nil { return err } // Listen for TCP packets on redirect port. - return listenTCP(acceptPort, sendloopDuration) + return listenTCP(ctx, acceptPort) } // LocalAction implements TestCase.LocalAction. -func (NATPreRedirectTCPPort) LocalAction(ip net.IP, ipv6 bool) error { - return connectTCP(ip, dropPort, sendloopDuration) +func (NATPreRedirectTCPPort) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { + return connectTCP(ctx, ip, dropPort) } // NATPreRedirectTCPOutgoing verifies that outgoing TCP connections aren't // affected by PREROUTING connection tracking. -type NATPreRedirectTCPOutgoing struct{} +type NATPreRedirectTCPOutgoing struct{ baseCase } // Name implements TestCase.Name. func (NATPreRedirectTCPOutgoing) Name() string { @@ -104,24 +104,24 @@ func (NATPreRedirectTCPOutgoing) Name() string { } // ContainerAction implements TestCase.ContainerAction. -func (NATPreRedirectTCPOutgoing) ContainerAction(ip net.IP, ipv6 bool) error { +func (NATPreRedirectTCPOutgoing) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { // Redirect all incoming TCP traffic to a closed port. if err := natTable(ipv6, "-A", "PREROUTING", "-p", "tcp", "-j", "REDIRECT", "--to-ports", fmt.Sprintf("%d", dropPort)); err != nil { return err } // Establish a connection to the host process. - return connectTCP(ip, acceptPort, sendloopDuration) + return connectTCP(ctx, ip, acceptPort) } // LocalAction implements TestCase.LocalAction. -func (NATPreRedirectTCPOutgoing) LocalAction(ip net.IP, ipv6 bool) error { - return listenTCP(acceptPort, sendloopDuration) +func (NATPreRedirectTCPOutgoing) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { + return listenTCP(ctx, acceptPort) } // NATOutRedirectTCPIncoming verifies that incoming TCP connections aren't // affected by OUTPUT connection tracking. -type NATOutRedirectTCPIncoming struct{} +type NATOutRedirectTCPIncoming struct{ baseCase } // Name implements TestCase.Name. func (NATOutRedirectTCPIncoming) Name() string { @@ -129,23 +129,23 @@ func (NATOutRedirectTCPIncoming) Name() string { } // ContainerAction implements TestCase.ContainerAction. -func (NATOutRedirectTCPIncoming) ContainerAction(ip net.IP, ipv6 bool) error { +func (NATOutRedirectTCPIncoming) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { // Redirect all outgoing TCP traffic to a closed port. if err := natTable(ipv6, "-A", "OUTPUT", "-p", "tcp", "-j", "REDIRECT", "--to-ports", fmt.Sprintf("%d", dropPort)); err != nil { return err } // Establish a connection to the host process. - return listenTCP(acceptPort, sendloopDuration) + return listenTCP(ctx, acceptPort) } // LocalAction implements TestCase.LocalAction. -func (NATOutRedirectTCPIncoming) LocalAction(ip net.IP, ipv6 bool) error { - return connectTCP(ip, acceptPort, sendloopDuration) +func (NATOutRedirectTCPIncoming) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { + return connectTCP(ctx, ip, acceptPort) } // NATOutRedirectUDPPort tests that packets are redirected to different port. -type NATOutRedirectUDPPort struct{} +type NATOutRedirectUDPPort struct{ containerCase } // Name implements TestCase.Name. func (NATOutRedirectUDPPort) Name() string { @@ -153,19 +153,19 @@ func (NATOutRedirectUDPPort) Name() string { } // ContainerAction implements TestCase.ContainerAction. -func (NATOutRedirectUDPPort) ContainerAction(ip net.IP, ipv6 bool) error { - return loopbackTest(ipv6, net.ParseIP(nowhereIP(ipv6)), "-A", "OUTPUT", "-p", "udp", "-j", "REDIRECT", "--to-ports", fmt.Sprintf("%d", acceptPort)) +func (NATOutRedirectUDPPort) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { + return loopbackTest(ctx, ipv6, net.ParseIP(nowhereIP(ipv6)), "-A", "OUTPUT", "-p", "udp", "-j", "REDIRECT", "--to-ports", fmt.Sprintf("%d", acceptPort)) } // LocalAction implements TestCase.LocalAction. -func (NATOutRedirectUDPPort) LocalAction(ip net.IP, ipv6 bool) error { +func (NATOutRedirectUDPPort) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { // No-op. return nil } // NATDropUDP tests that packets are not received in ports other than redirect // port. -type NATDropUDP struct{} +type NATDropUDP struct{ containerCase } // Name implements TestCase.Name. func (NATDropUDP) Name() string { @@ -173,25 +173,29 @@ func (NATDropUDP) Name() string { } // ContainerAction implements TestCase.ContainerAction. -func (NATDropUDP) ContainerAction(ip net.IP, ipv6 bool) error { +func (NATDropUDP) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { if err := natTable(ipv6, "-A", "PREROUTING", "-p", "udp", "-j", "REDIRECT", "--to-ports", fmt.Sprintf("%d", redirectPort)); err != nil { return err } - if err := listenUDP(acceptPort, sendloopDuration); err == nil { + timedCtx, cancel := context.WithTimeout(ctx, NegativeTimeout) + defer cancel() + if err := listenUDP(timedCtx, acceptPort); err == nil { return fmt.Errorf("packets on port %d should have been redirected to port %d", acceptPort, redirectPort) + } else if !errors.Is(err, context.DeadlineExceeded) { + return fmt.Errorf("error reading: %v", err) } return nil } // LocalAction implements TestCase.LocalAction. -func (NATDropUDP) LocalAction(ip net.IP, ipv6 bool) error { - return spawnUDPLoop(ip, acceptPort, sendloopDuration) +func (NATDropUDP) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { + return sendUDPLoop(ctx, ip, acceptPort) } // NATAcceptAll tests that all UDP packets are accepted. -type NATAcceptAll struct{} +type NATAcceptAll struct{ containerCase } // Name implements TestCase.Name. func (NATAcceptAll) Name() string { @@ -199,12 +203,12 @@ func (NATAcceptAll) Name() string { } // ContainerAction implements TestCase.ContainerAction. -func (NATAcceptAll) ContainerAction(ip net.IP, ipv6 bool) error { +func (NATAcceptAll) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { if err := natTable(ipv6, "-A", "PREROUTING", "-p", "udp", "-j", "ACCEPT"); err != nil { return err } - if err := listenUDP(acceptPort, sendloopDuration); err != nil { + if err := listenUDP(ctx, acceptPort); err != nil { return fmt.Errorf("packets on port %d should be allowed, but encountered an error: %v", acceptPort, err) } @@ -212,13 +216,13 @@ func (NATAcceptAll) ContainerAction(ip net.IP, ipv6 bool) error { } // LocalAction implements TestCase.LocalAction. -func (NATAcceptAll) LocalAction(ip net.IP, ipv6 bool) error { - return spawnUDPLoop(ip, acceptPort, sendloopDuration) +func (NATAcceptAll) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { + return sendUDPLoop(ctx, ip, acceptPort) } // NATOutRedirectIP uses iptables to select packets based on destination IP and // redirects them. -type NATOutRedirectIP struct{} +type NATOutRedirectIP struct{ baseCase } // Name implements TestCase.Name. func (NATOutRedirectIP) Name() string { @@ -226,9 +230,9 @@ func (NATOutRedirectIP) Name() string { } // ContainerAction implements TestCase.ContainerAction. -func (NATOutRedirectIP) ContainerAction(ip net.IP, ipv6 bool) error { +func (NATOutRedirectIP) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { // Redirect OUTPUT packets to a listening localhost port. - return loopbackTest(ipv6, net.ParseIP(nowhereIP(ipv6)), + return loopbackTest(ctx, ipv6, net.ParseIP(nowhereIP(ipv6)), "-A", "OUTPUT", "-d", nowhereIP(ipv6), "-p", "udp", @@ -236,14 +240,14 @@ func (NATOutRedirectIP) ContainerAction(ip net.IP, ipv6 bool) error { } // LocalAction implements TestCase.LocalAction. -func (NATOutRedirectIP) LocalAction(ip net.IP, ipv6 bool) error { +func (NATOutRedirectIP) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { // No-op. return nil } // NATOutDontRedirectIP tests that iptables matching with "-d" does not match // packets it shouldn't. -type NATOutDontRedirectIP struct{} +type NATOutDontRedirectIP struct{ localCase } // Name implements TestCase.Name. func (NATOutDontRedirectIP) Name() string { @@ -251,20 +255,20 @@ func (NATOutDontRedirectIP) Name() string { } // ContainerAction implements TestCase.ContainerAction. -func (NATOutDontRedirectIP) ContainerAction(ip net.IP, ipv6 bool) error { +func (NATOutDontRedirectIP) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { if err := natTable(ipv6, "-A", "OUTPUT", "-d", localIP(ipv6), "-p", "udp", "-j", "REDIRECT", "--to-port", fmt.Sprintf("%d", dropPort)); err != nil { return err } - return sendUDPLoop(ip, acceptPort, sendloopDuration) + return sendUDPLoop(ctx, ip, acceptPort) } // LocalAction implements TestCase.LocalAction. -func (NATOutDontRedirectIP) LocalAction(ip net.IP, ipv6 bool) error { - return listenUDP(acceptPort, sendloopDuration) +func (NATOutDontRedirectIP) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { + return listenUDP(ctx, acceptPort) } // NATOutRedirectInvert tests that iptables can match with "! -d". -type NATOutRedirectInvert struct{} +type NATOutRedirectInvert struct{ baseCase } // Name implements TestCase.Name. func (NATOutRedirectInvert) Name() string { @@ -272,13 +276,13 @@ func (NATOutRedirectInvert) Name() string { } // ContainerAction implements TestCase.ContainerAction. -func (NATOutRedirectInvert) ContainerAction(ip net.IP, ipv6 bool) error { +func (NATOutRedirectInvert) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { // Redirect OUTPUT packets to a listening localhost port. dest := "192.0.2.2" if ipv6 { dest = "2001:db8::2" } - return loopbackTest(ipv6, net.ParseIP(nowhereIP(ipv6)), + return loopbackTest(ctx, ipv6, net.ParseIP(nowhereIP(ipv6)), "-A", "OUTPUT", "!", "-d", dest, "-p", "udp", @@ -286,14 +290,14 @@ func (NATOutRedirectInvert) ContainerAction(ip net.IP, ipv6 bool) error { } // LocalAction implements TestCase.LocalAction. -func (NATOutRedirectInvert) LocalAction(ip net.IP, ipv6 bool) error { +func (NATOutRedirectInvert) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { // No-op. return nil } // NATPreRedirectIP tests that we can use iptables to select packets based on // destination IP and redirect them. -type NATPreRedirectIP struct{} +type NATPreRedirectIP struct{ containerCase } // Name implements TestCase.Name. func (NATPreRedirectIP) Name() string { @@ -301,7 +305,7 @@ func (NATPreRedirectIP) Name() string { } // ContainerAction implements TestCase.ContainerAction. -func (NATPreRedirectIP) ContainerAction(ip net.IP, ipv6 bool) error { +func (NATPreRedirectIP) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { addrs, err := localAddrs(ipv6) if err != nil { return err @@ -314,17 +318,17 @@ func (NATPreRedirectIP) ContainerAction(ip net.IP, ipv6 bool) error { if err := natTableRules(ipv6, rules); err != nil { return err } - return listenUDP(acceptPort, sendloopDuration) + return listenUDP(ctx, acceptPort) } // LocalAction implements TestCase.LocalAction. -func (NATPreRedirectIP) LocalAction(ip net.IP, ipv6 bool) error { - return spawnUDPLoop(ip, dropPort, sendloopDuration) +func (NATPreRedirectIP) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { + return sendUDPLoop(ctx, ip, dropPort) } // NATPreDontRedirectIP tests that iptables matching with "-d" does not match // packets it shouldn't. -type NATPreDontRedirectIP struct{} +type NATPreDontRedirectIP struct{ containerCase } // Name implements TestCase.Name. func (NATPreDontRedirectIP) Name() string { @@ -332,20 +336,20 @@ func (NATPreDontRedirectIP) Name() string { } // ContainerAction implements TestCase.ContainerAction. -func (NATPreDontRedirectIP) ContainerAction(ip net.IP, ipv6 bool) error { +func (NATPreDontRedirectIP) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { if err := natTable(ipv6, "-A", "PREROUTING", "-p", "udp", "-d", localIP(ipv6), "-j", "REDIRECT", "--to-ports", fmt.Sprintf("%d", dropPort)); err != nil { return err } - return listenUDP(acceptPort, sendloopDuration) + return listenUDP(ctx, acceptPort) } // LocalAction implements TestCase.LocalAction. -func (NATPreDontRedirectIP) LocalAction(ip net.IP, ipv6 bool) error { - return spawnUDPLoop(ip, acceptPort, sendloopDuration) +func (NATPreDontRedirectIP) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { + return sendUDPLoop(ctx, ip, acceptPort) } // NATPreRedirectInvert tests that iptables can match with "! -d". -type NATPreRedirectInvert struct{} +type NATPreRedirectInvert struct{ containerCase } // Name implements TestCase.Name. func (NATPreRedirectInvert) Name() string { @@ -353,21 +357,21 @@ func (NATPreRedirectInvert) Name() string { } // ContainerAction implements TestCase.ContainerAction. -func (NATPreRedirectInvert) ContainerAction(ip net.IP, ipv6 bool) error { +func (NATPreRedirectInvert) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { if err := natTable(ipv6, "-A", "PREROUTING", "-p", "udp", "!", "-d", localIP(ipv6), "-j", "REDIRECT", "--to-ports", fmt.Sprintf("%d", acceptPort)); err != nil { return err } - return listenUDP(acceptPort, sendloopDuration) + return listenUDP(ctx, acceptPort) } // LocalAction implements TestCase.LocalAction. -func (NATPreRedirectInvert) LocalAction(ip net.IP, ipv6 bool) error { - return spawnUDPLoop(ip, dropPort, sendloopDuration) +func (NATPreRedirectInvert) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { + return sendUDPLoop(ctx, ip, dropPort) } // NATRedirectRequiresProtocol tests that use of the --to-ports flag requires a // protocol to be specified with -p. -type NATRedirectRequiresProtocol struct{} +type NATRedirectRequiresProtocol struct{ baseCase } // Name implements TestCase.Name. func (NATRedirectRequiresProtocol) Name() string { @@ -375,7 +379,7 @@ func (NATRedirectRequiresProtocol) Name() string { } // ContainerAction implements TestCase.ContainerAction. -func (NATRedirectRequiresProtocol) ContainerAction(ip net.IP, ipv6 bool) error { +func (NATRedirectRequiresProtocol) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { if err := natTable(ipv6, "-A", "PREROUTING", "-d", localIP(ipv6), "-j", "REDIRECT", "--to-ports", fmt.Sprintf("%d", acceptPort)); err == nil { return errors.New("expected an error using REDIRECT --to-ports without a protocol") } @@ -383,13 +387,13 @@ func (NATRedirectRequiresProtocol) ContainerAction(ip net.IP, ipv6 bool) error { } // LocalAction implements TestCase.LocalAction. -func (NATRedirectRequiresProtocol) LocalAction(ip net.IP, ipv6 bool) error { +func (NATRedirectRequiresProtocol) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { // No-op. return nil } // NATOutRedirectTCPPort tests that connections are redirected on specified ports. -type NATOutRedirectTCPPort struct{} +type NATOutRedirectTCPPort struct{ baseCase } // Name implements TestCase.Name. func (NATOutRedirectTCPPort) Name() string { @@ -397,12 +401,11 @@ func (NATOutRedirectTCPPort) Name() string { } // ContainerAction implements TestCase.ContainerAction. -func (NATOutRedirectTCPPort) ContainerAction(ip net.IP, ipv6 bool) error { +func (NATOutRedirectTCPPort) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { if err := natTable(ipv6, "-A", "OUTPUT", "-p", "tcp", "-m", "tcp", "--dport", fmt.Sprintf("%d", dropPort), "-j", "REDIRECT", "--to-ports", fmt.Sprintf("%d", acceptPort)); err != nil { return err } - timeout := 20 * time.Second localAddr := net.TCPAddr{ IP: net.ParseIP(localIP(ipv6)), Port: acceptPort, @@ -416,9 +419,7 @@ func (NATOutRedirectTCPPort) ContainerAction(ip net.IP, ipv6 bool) error { defer lConn.Close() // Accept connections on port. - lConn.SetDeadline(time.Now().Add(timeout)) - err = connectTCP(ip, dropPort, timeout) - if err != nil { + if err := connectTCP(ctx, ip, dropPort); err != nil { return err } @@ -432,13 +433,13 @@ func (NATOutRedirectTCPPort) ContainerAction(ip net.IP, ipv6 bool) error { } // LocalAction implements TestCase.LocalAction. -func (NATOutRedirectTCPPort) LocalAction(ip net.IP, ipv6 bool) error { +func (NATOutRedirectTCPPort) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { return nil } // NATLoopbackSkipsPrerouting tests that packets sent via loopback aren't // affected by PREROUTING rules. -type NATLoopbackSkipsPrerouting struct{} +type NATLoopbackSkipsPrerouting struct{ baseCase } // Name implements TestCase.Name. func (NATLoopbackSkipsPrerouting) Name() string { @@ -446,7 +447,7 @@ func (NATLoopbackSkipsPrerouting) Name() string { } // ContainerAction implements TestCase.ContainerAction. -func (NATLoopbackSkipsPrerouting) ContainerAction(ip net.IP, ipv6 bool) error { +func (NATLoopbackSkipsPrerouting) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { // Redirect anything sent to localhost to an unused port. dest := []byte{127, 0, 0, 1} if err := natTable(ipv6, "-A", "PREROUTING", "-p", "tcp", "-j", "REDIRECT", "--to-port", fmt.Sprintf("%d", dropPort)); err != nil { @@ -457,24 +458,24 @@ func (NATLoopbackSkipsPrerouting) ContainerAction(ip net.IP, ipv6 bool) error { // loopback traffic, the connection would fail. sendCh := make(chan error) go func() { - sendCh <- connectTCP(dest, acceptPort, sendloopDuration) + sendCh <- connectTCP(ctx, dest, acceptPort) }() - if err := listenTCP(acceptPort, sendloopDuration); err != nil { + if err := listenTCP(ctx, acceptPort); err != nil { return err } return <-sendCh } // LocalAction implements TestCase.LocalAction. -func (NATLoopbackSkipsPrerouting) LocalAction(ip net.IP, ipv6 bool) error { +func (NATLoopbackSkipsPrerouting) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { // No-op. return nil } // NATPreOriginalDst tests that SO_ORIGINAL_DST returns the pre-NAT destination // of PREROUTING NATted packets. -type NATPreOriginalDst struct{} +type NATPreOriginalDst struct{ baseCase } // Name implements TestCase.Name. func (NATPreOriginalDst) Name() string { @@ -482,7 +483,7 @@ func (NATPreOriginalDst) Name() string { } // ContainerAction implements TestCase.ContainerAction. -func (NATPreOriginalDst) ContainerAction(ip net.IP, ipv6 bool) error { +func (NATPreOriginalDst) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { // Redirect incoming TCP connections to acceptPort. if err := natTable(ipv6, "-A", "PREROUTING", "-p", "tcp", @@ -495,17 +496,17 @@ func (NATPreOriginalDst) ContainerAction(ip net.IP, ipv6 bool) error { if err != nil { return err } - return listenForRedirectedConn(ipv6, addrs) + return listenForRedirectedConn(ctx, ipv6, addrs) } // LocalAction implements TestCase.LocalAction. -func (NATPreOriginalDst) LocalAction(ip net.IP, ipv6 bool) error { - return connectTCP(ip, dropPort, sendloopDuration) +func (NATPreOriginalDst) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { + return connectTCP(ctx, ip, dropPort) } // NATOutOriginalDst tests that SO_ORIGINAL_DST returns the pre-NAT destination // of OUTBOUND NATted packets. -type NATOutOriginalDst struct{} +type NATOutOriginalDst struct{ baseCase } // Name implements TestCase.Name. func (NATOutOriginalDst) Name() string { @@ -513,7 +514,7 @@ func (NATOutOriginalDst) Name() string { } // ContainerAction implements TestCase.ContainerAction. -func (NATOutOriginalDst) ContainerAction(ip net.IP, ipv6 bool) error { +func (NATOutOriginalDst) ContainerAction(ctx context.Context, ip net.IP, ipv6 bool) error { // Redirect incoming TCP connections to acceptPort. if err := natTable(ipv6, "-A", "OUTPUT", "-p", "tcp", "-j", "REDIRECT", "--to-port", fmt.Sprintf("%d", acceptPort)); err != nil { return err @@ -521,22 +522,22 @@ func (NATOutOriginalDst) ContainerAction(ip net.IP, ipv6 bool) error { connCh := make(chan error) go func() { - connCh <- connectTCP(ip, dropPort, sendloopDuration) + connCh <- connectTCP(ctx, ip, dropPort) }() - if err := listenForRedirectedConn(ipv6, []net.IP{ip}); err != nil { + if err := listenForRedirectedConn(ctx, ipv6, []net.IP{ip}); err != nil { return err } return <-connCh } // LocalAction implements TestCase.LocalAction. -func (NATOutOriginalDst) LocalAction(ip net.IP, ipv6 bool) error { +func (NATOutOriginalDst) LocalAction(ctx context.Context, ip net.IP, ipv6 bool) error { // No-op. return nil } -func listenForRedirectedConn(ipv6 bool, originalDsts []net.IP) error { +func listenForRedirectedConn(ctx context.Context, ipv6 bool, originalDsts []net.IP) error { // The net package doesn't give guarantee access to the connection's // underlying FD, and thus we cannot call getsockopt. We have to use // traditional syscalls for SO_ORIGINAL_DST. @@ -572,16 +573,32 @@ func listenForRedirectedConn(ipv6 bool, originalDsts []net.IP) error { return err } - connfd, _, err := syscall.Accept(sockfd) - if err != nil { + // Block on accept() in another goroutine. + connCh := make(chan int) + errCh := make(chan error) + go func() { + connFD, _, err := syscall.Accept(sockfd) + if err != nil { + errCh <- err + } + connCh <- connFD + }() + + // Wait for accept() to return or for the context to finish. + var connFD int + select { + case <-ctx.Done(): + return ctx.Err() + case err := <-errCh: return err + case connFD = <-connCh: } - defer syscall.Close(connfd) + defer syscall.Close(connFD) // Verify that, despite listening on acceptPort, SO_ORIGINAL_DST // indicates the packet was sent to originalDst:dropPort. if ipv6 { - got, err := originalDestination6(connfd) + got, err := originalDestination6(connFD) if err != nil { return err } @@ -598,7 +615,7 @@ func listenForRedirectedConn(ipv6 bool, originalDsts []net.IP) error { } return fmt.Errorf("SO_ORIGINAL_DST returned %+v, but wanted one of %+v (note: port numbers are in network byte order)", got, originalDsts) } else { - got, err := originalDestination4(connfd) + got, err := originalDestination4(connFD) if err != nil { return err } @@ -619,26 +636,22 @@ func listenForRedirectedConn(ipv6 bool, originalDsts []net.IP) error { // loopbackTests runs an iptables rule and ensures that packets sent to // dest:dropPort are received by localhost:acceptPort. -func loopbackTest(ipv6 bool, dest net.IP, args ...string) error { +func loopbackTest(ctx context.Context, ipv6 bool, dest net.IP, args ...string) error { if err := natTable(ipv6, args...); err != nil { return err } - sendCh := make(chan error) - listenCh := make(chan error) + sendCh := make(chan error, 1) + listenCh := make(chan error, 1) go func() { - sendCh <- sendUDPLoop(dest, dropPort, sendloopDuration) + sendCh <- sendUDPLoop(ctx, dest, dropPort) }() go func() { - listenCh <- listenUDP(acceptPort, sendloopDuration) + listenCh <- listenUDP(ctx, acceptPort) }() select { case err := <-listenCh: - if err != nil { - return err - } - case <-time.After(sendloopDuration): - return errors.New("timed out") + return err + case err := <-sendCh: + return err } - // sendCh will always take the full sendloop time. - return <-sendCh } |