diff options
43 files changed, 1449 insertions, 668 deletions
diff --git a/g3doc/user_guide/install.md b/g3doc/user_guide/install.md index 85ba6a161..73e953905 100644 --- a/g3doc/user_guide/install.md +++ b/g3doc/user_guide/install.md @@ -51,18 +51,17 @@ sudo apt-get install -y \ apt-transport-https \ ca-certificates \ curl \ - gnupg-agent \ - software-properties-common + gnupg ``` Next, configure the key used to sign archives and the repository. NOTE: The key was updated on 2021-07-13 to replace the expired key. If you get -errors about the key being expired, run the `apt-key add` command below again. +errors about the key being expired, run the `curl` command below again. ```bash -curl -fsSL https://gvisor.dev/archive.key | sudo apt-key add - -sudo add-apt-repository "deb [arch=amd64,arm64] https://storage.googleapis.com/gvisor/releases release main" +curl -fsSL https://gvisor.dev/archive.key | sudo gpg --dearmor -o /usr/share/keyrings/gvisor-archive-keyring.gpg +echo "deb [arch=$(dpkg --print-architecture) signed-by=/usr/share/keyrings/gvisor-archive-keyring.gpg] https://storage.googleapis.com/gvisor/releases release main" | sudo tee /etc/apt/sources.list.d/gvisor.list > /dev/null ``` Now the runsc package can be installed: diff --git a/pkg/crypto/crypto_stdlib.go b/pkg/crypto/crypto_stdlib.go index 69e867386..28eba2ff6 100644 --- a/pkg/crypto/crypto_stdlib.go +++ b/pkg/crypto/crypto_stdlib.go @@ -19,14 +19,21 @@ package crypto import ( "crypto/ecdsa" + "crypto/elliptic" "crypto/sha512" + "fmt" "math/big" ) -// EcdsaVerify verifies the signature in r, s of hash using ECDSA and the -// public key, pub. Its return value records whether the signature is valid. -func EcdsaVerify(pub *ecdsa.PublicKey, hash []byte, r, s *big.Int) (bool, error) { - return ecdsa.Verify(pub, hash, r, s), nil +// EcdsaP384Sha384Verify verifies the signature in r, s of hash using ECDSA +// P384 + SHA 384 and the public key, pub. Its return value records whether +// the signature is valid. +func EcdsaP384Sha384Verify(pub *ecdsa.PublicKey, data []byte, r, s *big.Int) (bool, error) { + if pub.Curve != elliptic.P384() { + return false, fmt.Errorf("unsupported key curve: want P-384, got %v", pub.Curve) + } + digest := sha512.Sum384(data) + return ecdsa.Verify(pub, digest[:], r, s), nil } // SumSha384 returns the SHA384 checksum of the data. diff --git a/pkg/sentry/platform/kvm/BUILD b/pkg/sentry/platform/kvm/BUILD index 9aba92e91..a26f54269 100644 --- a/pkg/sentry/platform/kvm/BUILD +++ b/pkg/sentry/platform/kvm/BUILD @@ -87,6 +87,12 @@ go_test( "virtual_map_test.go", ], library = ":kvm", + # FIXME(gvisor.dev/issue/3374): Not working with all build systems. + nogo = False, + # cgo has to be disabled. We have seen libc that blocks all signals and + # calls mmap from pthread_create, but we use SIGSYS to trap mmap system + # calls. + pure = True, tags = [ "manual", "nogotsan", @@ -106,12 +112,6 @@ go_test( "//pkg/sentry/time", "@org_golang_x_sys//unix:go_default_library", ], - # FIXME(gvisor.dev/issue/3374): Not working with all build systems. - nogo = False, - # cgo has to be disabled. We have seen libc that blocks all signals and - # calls mmap from pthread_create, but we use SIGSYS to trap mmap system - # calls. - pure = True, ) genrule( diff --git a/pkg/sentry/platform/kvm/kvm_safecopy_test.go b/pkg/sentry/platform/kvm/kvm_safecopy_test.go index cbfc61919..9a87c9e6f 100644 --- a/pkg/sentry/platform/kvm/kvm_safecopy_test.go +++ b/pkg/sentry/platform/kvm/kvm_safecopy_test.go @@ -102,4 +102,3 @@ func TestSafecopy(t *testing.T) { } }) } - diff --git a/pkg/sentry/platform/kvm/machine.go b/pkg/sentry/platform/kvm/machine.go index dce6960f9..dcf34015d 100644 --- a/pkg/sentry/platform/kvm/machine.go +++ b/pkg/sentry/platform/kvm/machine.go @@ -279,7 +279,7 @@ func newMachine(vm int) (*machine, error) { } } - for _, vr := range(readOnlyGuestRegions) { + for _, vr := range readOnlyGuestRegions { mapRegion(vr, _KVM_MEM_READONLY) } diff --git a/pkg/sentry/platform/kvm/machine_amd64.go b/pkg/sentry/platform/kvm/machine_amd64.go index 52e0a2f37..ab1e036b7 100644 --- a/pkg/sentry/platform/kvm/machine_amd64.go +++ b/pkg/sentry/platform/kvm/machine_amd64.go @@ -502,6 +502,6 @@ func (m *machine) getNewVCPU() *vCPU { return nil } -func archPhysicalRegions(physicalRegions []physicalRegion) ([]physicalRegion) { +func archPhysicalRegions(physicalRegions []physicalRegion) []physicalRegion { return physicalRegions } diff --git a/pkg/sentry/platform/kvm/machine_arm64.go b/pkg/sentry/platform/kvm/machine_arm64.go index 2a25a757f..08d98c479 100644 --- a/pkg/sentry/platform/kvm/machine_arm64.go +++ b/pkg/sentry/platform/kvm/machine_arm64.go @@ -112,7 +112,7 @@ func rdonlyRegionsForSetMem() (phyRegions []physicalRegion) { // archPhysicalRegions fills readOnlyGuestRegions and allocates separate // physical regions form them. -func archPhysicalRegions(physicalRegions []physicalRegion) ([]physicalRegion) { +func archPhysicalRegions(physicalRegions []physicalRegion) []physicalRegion { applyVirtualRegions(func(vr virtualRegion) { if excludeVirtualRegion(vr) { return // skip region. @@ -125,26 +125,26 @@ func archPhysicalRegions(physicalRegions []physicalRegion) ([]physicalRegion) { rdRegions := readOnlyGuestRegions[:] // Add an unreachable region. - rdRegions = append(rdRegions, region { + rdRegions = append(rdRegions, region{ virtual: 0xffffffffffffffff, - length: 0, + length: 0, }) - var regions []physicalRegion + var regions []physicalRegion addValidRegion := func(r *physicalRegion, virtual, length uintptr) { if length == 0 { return } - regions = append(regions, physicalRegion { - region: region{ - virtual: virtual, - length: length, - }, - physical: r.physical + (virtual - r.virtual), + regions = append(regions, physicalRegion{ + region: region{ + virtual: virtual, + length: length, + }, + physical: r.physical + (virtual - r.virtual), }) } i := 0 - for _, pr := range(physicalRegions) { + for _, pr := range physicalRegions { start := pr.virtual end := pr.virtual + pr.length for start < end { @@ -160,16 +160,16 @@ func archPhysicalRegions(physicalRegions []physicalRegion) ([]physicalRegion) { if end < rdStart { newEnd = end } - addValidRegion(&pr, start, newEnd - start) + addValidRegion(&pr, start, newEnd-start) start = rdStart continue } if rdEnd < end { - addValidRegion(&pr, start, rdEnd - start) + addValidRegion(&pr, start, rdEnd-start) start = rdEnd continue } - addValidRegion(&pr, start, end - start) + addValidRegion(&pr, start, end-start) start = end } } @@ -178,7 +178,7 @@ func archPhysicalRegions(physicalRegions []physicalRegion) ([]physicalRegion) { } // Get all available physicalRegions. -func availableRegionsForSetMem() ([]physicalRegion) { +func availableRegionsForSetMem() []physicalRegion { var excludedRegions []region applyVirtualRegions(func(vr virtualRegion) { if !vr.accessType.Write { @@ -187,9 +187,9 @@ func availableRegionsForSetMem() ([]physicalRegion) { }) // Add an unreachable region. - excludedRegions = append(excludedRegions, region { + excludedRegions = append(excludedRegions, region{ virtual: 0xffffffffffffffff, - length: 0, + length: 0, }) var regions []physicalRegion @@ -197,16 +197,16 @@ func availableRegionsForSetMem() ([]physicalRegion) { if length == 0 { return } - regions = append(regions, physicalRegion { - region: region{ - virtual: virtual, - length: length, - }, - physical: r.physical + (virtual - r.virtual), + regions = append(regions, physicalRegion{ + region: region{ + virtual: virtual, + length: length, + }, + physical: r.physical + (virtual - r.virtual), }) } i := 0 - for _, pr := range(physicalRegions) { + for _, pr := range physicalRegions { start := pr.virtual end := pr.virtual + pr.length for start < end { @@ -226,7 +226,7 @@ func availableRegionsForSetMem() ([]physicalRegion) { if rend > end { rend = end } - addValidRegion(&pr, start, rend - start) + addValidRegion(&pr, start, rend-start) start = excludeEnd } } diff --git a/pkg/sentry/seccheck/BUILD b/pkg/sentry/seccheck/BUILD index 943fa180d..35feb969f 100644 --- a/pkg/sentry/seccheck/BUILD +++ b/pkg/sentry/seccheck/BUILD @@ -8,6 +8,8 @@ go_fieldenum( name = "seccheck_fieldenum", srcs = [ "clone.go", + "execve.go", + "exit.go", "task.go", ], out = "seccheck_fieldenum.go", @@ -29,6 +31,8 @@ go_library( name = "seccheck", srcs = [ "clone.go", + "execve.go", + "exit.go", "seccheck.go", "seccheck_fieldenum.go", "seqatomic_checkerslice_unsafe.go", diff --git a/pkg/sentry/seccheck/execve.go b/pkg/sentry/seccheck/execve.go new file mode 100644 index 000000000..f36e0730e --- /dev/null +++ b/pkg/sentry/seccheck/execve.go @@ -0,0 +1,65 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package seccheck + +import ( + "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/sentry/kernel/auth" +) + +// ExecveInfo contains information used by the Execve checkpoint. +// +// +fieldenum Execve +type ExecveInfo struct { + // Invoker identifies the invoking thread. + Invoker TaskInfo + + // Credentials are the invoking thread's credentials. + Credentials *auth.Credentials + + // BinaryPath is a path to the executable binary file being switched to in + // the mount namespace in which it was opened. + BinaryPath string + + // Argv is the new process image's argument vector. + Argv []string + + // Env is the new process image's environment variables. + Env []string + + // BinaryMode is the executable binary file's mode. + BinaryMode uint16 + + // BinarySHA256 is the SHA-256 hash of the executable binary file. + // + // Note that this requires reading the entire file into memory, which is + // likely to be extremely slow. + BinarySHA256 [32]byte +} + +// ExecveReq returns fields required by the Execve checkpoint. +func (s *state) ExecveReq() ExecveFieldSet { + return s.execveReq.Load() +} + +// Execve is called at the Execve checkpoint. +func (s *state) Execve(ctx context.Context, mask ExecveFieldSet, info *ExecveInfo) error { + for _, c := range s.getCheckers() { + if err := c.Execve(ctx, mask, *info); err != nil { + return err + } + } + return nil +} diff --git a/pkg/sentry/seccheck/exit.go b/pkg/sentry/seccheck/exit.go new file mode 100644 index 000000000..69cb6911c --- /dev/null +++ b/pkg/sentry/seccheck/exit.go @@ -0,0 +1,57 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package seccheck + +import ( + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/context" +) + +// ExitNotifyParentInfo contains information used by the ExitNotifyParent +// checkpoint. +// +// +fieldenum ExitNotifyParent +type ExitNotifyParentInfo struct { + // Exiter identifies the exiting thread. Note that by the checkpoint's + // definition, Exiter.ThreadID == Exiter.ThreadGroupID and + // Exiter.ThreadStartTime == Exiter.ThreadGroupStartTime, so requesting + // ThreadGroup* fields is redundant. + Exiter TaskInfo + + // ExitStatus is the exiting thread group's exit status, as reported + // by wait*(). + ExitStatus linux.WaitStatus +} + +// ExitNotifyParentReq returns fields required by the ExitNotifyParent +// checkpoint. +func (s *state) ExitNotifyParentReq() ExitNotifyParentFieldSet { + return s.exitNotifyParentReq.Load() +} + +// ExitNotifyParent is called at the ExitNotifyParent checkpoint. +// +// The ExitNotifyParent checkpoint occurs when a zombied thread group leader, +// not waiting for exit acknowledgement from a non-parent ptracer, becomes the +// last non-dead thread in its thread group and notifies its parent of its +// exiting. +func (s *state) ExitNotifyParent(ctx context.Context, mask ExitNotifyParentFieldSet, info *ExitNotifyParentInfo) error { + for _, c := range s.getCheckers() { + if err := c.ExitNotifyParent(ctx, mask, *info); err != nil { + return err + } + } + return nil +} diff --git a/pkg/sentry/seccheck/seccheck.go b/pkg/sentry/seccheck/seccheck.go index b6c9d44ce..e13274096 100644 --- a/pkg/sentry/seccheck/seccheck.go +++ b/pkg/sentry/seccheck/seccheck.go @@ -29,6 +29,8 @@ type Point uint // PointX represents the checkpoint X. const ( PointClone Point = iota + PointExecve + PointExitNotifyParent // Add new Points above this line. pointLength @@ -47,6 +49,8 @@ const ( // registered concurrently with invocations of checkpoints). type Checker interface { Clone(ctx context.Context, mask CloneFieldSet, info CloneInfo) error + Execve(ctx context.Context, mask ExecveFieldSet, info ExecveInfo) error + ExitNotifyParent(ctx context.Context, mask ExitNotifyParentFieldSet, info ExitNotifyParentInfo) error } // CheckerDefaults may be embedded by implementations of Checker to obtain @@ -58,6 +62,16 @@ func (CheckerDefaults) Clone(ctx context.Context, mask CloneFieldSet, info Clone return nil } +// Execve implements Checker.Execve. +func (CheckerDefaults) Execve(ctx context.Context, mask ExecveFieldSet, info ExecveInfo) error { + return nil +} + +// ExitNotifyParent implements Checker.ExitNotifyParent. +func (CheckerDefaults) ExitNotifyParent(ctx context.Context, mask ExitNotifyParentFieldSet, info ExitNotifyParentInfo) error { + return nil +} + // CheckerReq indicates what checkpoints a corresponding Checker runs at, and // what information it requires at those checkpoints. type CheckerReq struct { @@ -69,7 +83,9 @@ type CheckerReq struct { // All of the following fields indicate what fields in the corresponding // XInfo struct will be requested at the corresponding checkpoint. - Clone CloneFields + Clone CloneFields + Execve ExecveFields + ExitNotifyParent ExitNotifyParentFields } // Global is the method receiver of all seccheck functions. @@ -101,7 +117,9 @@ type state struct { // corresponding XInfo struct have been requested by any registered // checker, are accessed using atomic memory operations, and are mutated // with registrationMu locked. - cloneReq CloneFieldSet + cloneReq CloneFieldSet + execveReq ExecveFieldSet + exitNotifyParentReq ExitNotifyParentFieldSet } // AppendChecker registers the given Checker to execute at checkpoints. The @@ -110,7 +128,11 @@ type state struct { func (s *state) AppendChecker(c Checker, req *CheckerReq) { s.registrationMu.Lock() defer s.registrationMu.Unlock() + s.cloneReq.AddFieldsLoadable(req.Clone) + s.execveReq.AddFieldsLoadable(req.Execve) + s.exitNotifyParentReq.AddFieldsLoadable(req.ExitNotifyParent) + s.appendCheckerLocked(c) for _, p := range req.Points { word, bit := p/32, p%32 diff --git a/pkg/sentry/socket/netfilter/targets.go b/pkg/sentry/socket/netfilter/targets.go index ea56f39c1..0f6e576a9 100644 --- a/pkg/sentry/socket/netfilter/targets.go +++ b/pkg/sentry/socket/netfilter/targets.go @@ -647,7 +647,7 @@ func (jt *JumpTarget) id() targetID { } // Action implements stack.Target.Action. -func (jt *JumpTarget) Action(*stack.PacketBuffer, *stack.ConnTrack, stack.Hook, *stack.Route, tcpip.Address) (stack.RuleVerdict, int) { +func (jt *JumpTarget) Action(*stack.PacketBuffer, *stack.ConnTrack, stack.Hook, *stack.Route, stack.AddressableEndpoint) (stack.RuleVerdict, int) { return stack.RuleJump, jt.RuleNum } diff --git a/pkg/tcpip/network/ipv4/icmp.go b/pkg/tcpip/network/ipv4/icmp.go index 2aa38eb98..d51c36f19 100644 --- a/pkg/tcpip/network/ipv4/icmp.go +++ b/pkg/tcpip/network/ipv4/icmp.go @@ -240,12 +240,6 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer) { case header.ICMPv4Echo: received.echoRequest.Increment() - sent := e.stats.icmp.packetsSent - if !e.protocol.stack.AllowICMPMessage() { - sent.rateLimited.Increment() - return - } - // DeliverTransportPacket will take ownership of pkt so don't use it beyond // this point. Make a deep copy of the data before pkt gets sent as we will // be modifying fields. @@ -281,6 +275,12 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer) { } defer r.Release() + sent := e.stats.icmp.packetsSent + if !e.protocol.allowICMPReply(header.ICMPv4EchoReply, header.ICMPv4UnusedCode) { + sent.rateLimited.Increment() + return + } + // TODO(gvisor.dev/issue/3810:) When adding protocol numbers into the // header information, we may have to change this code to handle the // ICMP header no longer being in the data buffer. @@ -562,13 +562,6 @@ func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) tcpip return &tcpip.ErrNotConnected{} } - sent := netEP.stats.icmp.packetsSent - - if !p.stack.AllowICMPMessage() { - sent.rateLimited.Increment() - return nil - } - transportHeader := pkt.TransportHeader().View() // Don't respond to icmp error packets. @@ -606,6 +599,35 @@ func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) tcpip } } + sent := netEP.stats.icmp.packetsSent + icmpType, icmpCode, counter, pointer := func() (header.ICMPv4Type, header.ICMPv4Code, tcpip.MultiCounterStat, byte) { + switch reason := reason.(type) { + case *icmpReasonPortUnreachable: + return header.ICMPv4DstUnreachable, header.ICMPv4PortUnreachable, sent.dstUnreachable, 0 + case *icmpReasonProtoUnreachable: + return header.ICMPv4DstUnreachable, header.ICMPv4ProtoUnreachable, sent.dstUnreachable, 0 + case *icmpReasonNetworkUnreachable: + return header.ICMPv4DstUnreachable, header.ICMPv4NetUnreachable, sent.dstUnreachable, 0 + case *icmpReasonHostUnreachable: + return header.ICMPv4DstUnreachable, header.ICMPv4HostUnreachable, sent.dstUnreachable, 0 + case *icmpReasonFragmentationNeeded: + return header.ICMPv4DstUnreachable, header.ICMPv4FragmentationNeeded, sent.dstUnreachable, 0 + case *icmpReasonTTLExceeded: + return header.ICMPv4TimeExceeded, header.ICMPv4TTLExceeded, sent.timeExceeded, 0 + case *icmpReasonReassemblyTimeout: + return header.ICMPv4TimeExceeded, header.ICMPv4ReassemblyTimeout, sent.timeExceeded, 0 + case *icmpReasonParamProblem: + return header.ICMPv4ParamProblem, header.ICMPv4UnusedCode, sent.paramProblem, reason.pointer + default: + panic(fmt.Sprintf("unsupported ICMP type %T", reason)) + } + }() + + if !p.allowICMPReply(icmpType, icmpCode) { + sent.rateLimited.Increment() + return nil + } + // Now work out how much of the triggering packet we should return. // As per RFC 1812 Section 4.3.2.3 // @@ -658,44 +680,9 @@ func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) tcpip icmpPkt.TransportProtocolNumber = header.ICMPv4ProtocolNumber icmpHdr := header.ICMPv4(icmpPkt.TransportHeader().Push(header.ICMPv4MinimumSize)) - var counter tcpip.MultiCounterStat - switch reason := reason.(type) { - case *icmpReasonPortUnreachable: - icmpHdr.SetType(header.ICMPv4DstUnreachable) - icmpHdr.SetCode(header.ICMPv4PortUnreachable) - counter = sent.dstUnreachable - case *icmpReasonProtoUnreachable: - icmpHdr.SetType(header.ICMPv4DstUnreachable) - icmpHdr.SetCode(header.ICMPv4ProtoUnreachable) - counter = sent.dstUnreachable - case *icmpReasonNetworkUnreachable: - icmpHdr.SetType(header.ICMPv4DstUnreachable) - icmpHdr.SetCode(header.ICMPv4NetUnreachable) - counter = sent.dstUnreachable - case *icmpReasonHostUnreachable: - icmpHdr.SetType(header.ICMPv4DstUnreachable) - icmpHdr.SetCode(header.ICMPv4HostUnreachable) - counter = sent.dstUnreachable - case *icmpReasonFragmentationNeeded: - icmpHdr.SetType(header.ICMPv4DstUnreachable) - icmpHdr.SetCode(header.ICMPv4FragmentationNeeded) - counter = sent.dstUnreachable - case *icmpReasonTTLExceeded: - icmpHdr.SetType(header.ICMPv4TimeExceeded) - icmpHdr.SetCode(header.ICMPv4TTLExceeded) - counter = sent.timeExceeded - case *icmpReasonReassemblyTimeout: - icmpHdr.SetType(header.ICMPv4TimeExceeded) - icmpHdr.SetCode(header.ICMPv4ReassemblyTimeout) - counter = sent.timeExceeded - case *icmpReasonParamProblem: - icmpHdr.SetType(header.ICMPv4ParamProblem) - icmpHdr.SetCode(header.ICMPv4UnusedCode) - icmpHdr.SetPointer(reason.pointer) - counter = sent.paramProblem - default: - panic(fmt.Sprintf("unsupported ICMP type %T", reason)) - } + icmpHdr.SetCode(icmpCode) + icmpHdr.SetType(icmpType) + icmpHdr.SetPointer(pointer) icmpHdr.SetChecksum(header.ICMPv4Checksum(icmpHdr, icmpPkt.Data().AsRange().Checksum())) if err := route.WritePacket( diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go index 63223bc92..dda473e48 100644 --- a/pkg/tcpip/network/ipv4/ipv4.go +++ b/pkg/tcpip/network/ipv4/ipv4.go @@ -426,7 +426,7 @@ func (e *endpoint) WritePacket(r *stack.Route, params stack.NetworkHeaderParams, // iptables filtering. All packets that reach here are locally // generated. outNicName := e.protocol.stack.FindNICNameFromID(e.nic.ID()) - if ok := e.protocol.stack.IPTables().Check(stack.Output, pkt, r, "" /* preroutingAddr */, "" /* inNicName */, outNicName); !ok { + if ok := e.protocol.stack.IPTables().CheckOutput(pkt, r, outNicName); !ok { // iptables is telling us to drop the packet. e.stats.ip.IPTablesOutputDropped.Increment() return nil @@ -466,7 +466,7 @@ func (e *endpoint) writePacket(r *stack.Route, pkt *stack.PacketBuffer, headerIn // Postrouting NAT can only change the source address, and does not alter the // route or outgoing interface of the packet. outNicName := e.protocol.stack.FindNICNameFromID(e.nic.ID()) - if ok := e.protocol.stack.IPTables().Check(stack.Postrouting, pkt, r, "" /* preroutingAddr */, "" /* inNicName */, outNicName); !ok { + if ok := e.protocol.stack.IPTables().CheckPostrouting(pkt, r, outNicName); !ok { // iptables is telling us to drop the packet. e.stats.ip.IPTablesPostroutingDropped.Increment() return nil @@ -549,7 +549,7 @@ func (e *endpoint) WritePackets(r *stack.Route, pkts stack.PacketBufferList, par outNicName := e.protocol.stack.FindNICNameFromID(e.nic.ID()) // iptables filtering. All packets that reach here are locally // generated. - outputDropped, natPkts := e.protocol.stack.IPTables().CheckPackets(stack.Output, pkts, r, "" /* inNicName */, outNicName) + outputDropped, natPkts := e.protocol.stack.IPTables().CheckOutputPackets(pkts, r, outNicName) stats.IPTablesOutputDropped.IncrementBy(uint64(len(outputDropped))) for pkt := range outputDropped { pkts.Remove(pkt) @@ -576,7 +576,7 @@ func (e *endpoint) WritePackets(r *stack.Route, pkts stack.PacketBufferList, par // We ignore the list of NAT-ed packets here because Postrouting NAT can only // change the source address, and does not alter the route or outgoing // interface of the packet. - postroutingDropped, _ := e.protocol.stack.IPTables().CheckPackets(stack.Postrouting, pkts, r, "" /* inNicName */, outNicName) + postroutingDropped, _ := e.protocol.stack.IPTables().CheckPostroutingPackets(pkts, r, outNicName) stats.IPTablesPostroutingDropped.IncrementBy(uint64(len(postroutingDropped))) for pkt := range postroutingDropped { pkts.Remove(pkt) @@ -717,7 +717,7 @@ func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) ip.ForwardingError { if ep := e.protocol.findEndpointWithAddress(dstAddr); ep != nil { inNicName := stk.FindNICNameFromID(e.nic.ID()) outNicName := stk.FindNICNameFromID(ep.nic.ID()) - if ok := stk.IPTables().Check(stack.Forward, pkt, nil, "" /* preroutingAddr */, inNicName, outNicName); !ok { + if ok := stk.IPTables().CheckForward(pkt, inNicName, outNicName); !ok { // iptables is telling us to drop the packet. e.stats.ip.IPTablesForwardDropped.Increment() return nil @@ -744,7 +744,7 @@ func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) ip.ForwardingError { inNicName := stk.FindNICNameFromID(e.nic.ID()) outNicName := stk.FindNICNameFromID(r.NICID()) - if ok := stk.IPTables().Check(stack.Forward, pkt, nil, "" /* preroutingAddr */, inNicName, outNicName); !ok { + if ok := stk.IPTables().CheckForward(pkt, inNicName, outNicName); !ok { // iptables is telling us to drop the packet. e.stats.ip.IPTablesForwardDropped.Increment() return nil @@ -841,7 +841,7 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) { // Loopback traffic skips the prerouting chain. inNicName := e.protocol.stack.FindNICNameFromID(e.nic.ID()) - if ok := e.protocol.stack.IPTables().Check(stack.Prerouting, pkt, nil, e.MainAddress().Address, inNicName, "" /* outNicName */); !ok { + if ok := e.protocol.stack.IPTables().CheckPrerouting(pkt, e, inNicName); !ok { // iptables is telling us to drop the packet. stats.IPTablesPreroutingDropped.Increment() return @@ -940,7 +940,7 @@ func (e *endpoint) handleValidatedPacket(h header.IPv4, pkt *stack.PacketBuffer, // iptables filtering. All packets that reach here are intended for // this machine and will not be forwarded. - if ok := e.protocol.stack.IPTables().Check(stack.Input, pkt, nil, "" /* preroutingAddr */, inNICName, "" /* outNicName */); !ok { + if ok := e.protocol.stack.IPTables().CheckInput(pkt, inNICName); !ok { // iptables is telling us to drop the packet. stats.ip.IPTablesInputDropped.Increment() return @@ -1215,6 +1215,9 @@ type protocol struct { // eps is keyed by NICID to allow protocol methods to retrieve an endpoint // when handling a packet, by looking at which NIC handled the packet. eps map[tcpip.NICID]*endpoint + + // ICMP types for which the stack's global rate limiting must apply. + icmpRateLimitedTypes map[header.ICMPv4Type]struct{} } // defaultTTL is the current default TTL for the protocol. Only the @@ -1330,6 +1333,23 @@ func (*protocol) Parse(pkt *stack.PacketBuffer) (proto tcpip.TransportProtocolNu return ipHdr.TransportProtocol(), !ipHdr.More() && ipHdr.FragmentOffset() == 0, true } +// allowICMPReply reports whether an ICMP reply with provided type and code may +// be sent following the rate mask options and global ICMP rate limiter. +func (p *protocol) allowICMPReply(icmpType header.ICMPv4Type, code header.ICMPv4Code) bool { + // Mimic linux and never rate limit for PMTU discovery. + // https://github.com/torvalds/linux/blob/9e9fb7655ed585da8f468e29221f0ba194a5f613/net/ipv4/icmp.c#L288 + if icmpType == header.ICMPv4DstUnreachable && code == header.ICMPv4FragmentationNeeded { + return true + } + p.mu.RLock() + defer p.mu.RUnlock() + + if _, ok := p.mu.icmpRateLimitedTypes[icmpType]; ok { + return p.stack.AllowICMPMessage() + } + return true +} + // calculateNetworkMTU calculates the network-layer payload MTU based on the // link-layer payload mtu. func calculateNetworkMTU(linkMTU, networkHeaderSize uint32) (uint32, tcpip.Error) { @@ -1409,6 +1429,14 @@ func NewProtocolWithOptions(opts Options) stack.NetworkProtocolFactory { } p.fragmentation = fragmentation.NewFragmentation(fragmentblockSize, fragmentation.HighFragThreshold, fragmentation.LowFragThreshold, ReassembleTimeout, s.Clock(), p) p.mu.eps = make(map[tcpip.NICID]*endpoint) + // Set ICMP rate limiting to Linux defaults. + // See https://man7.org/linux/man-pages/man7/icmp.7.html. + p.mu.icmpRateLimitedTypes = map[header.ICMPv4Type]struct{}{ + header.ICMPv4DstUnreachable: struct{}{}, + header.ICMPv4SrcQuench: struct{}{}, + header.ICMPv4TimeExceeded: struct{}{}, + header.ICMPv4ParamProblem: struct{}{}, + } return p } } diff --git a/pkg/tcpip/network/ipv4/ipv4_test.go b/pkg/tcpip/network/ipv4/ipv4_test.go index e7b5b3ea2..ef91245d7 100644 --- a/pkg/tcpip/network/ipv4/ipv4_test.go +++ b/pkg/tcpip/network/ipv4/ipv4_test.go @@ -3373,3 +3373,139 @@ func TestCloseLocking(t *testing.T) { } }() } + +func TestIcmpRateLimit(t *testing.T) { + var ( + host1IPv4Addr = tcpip.ProtocolAddress{ + Protocol: ipv4.ProtocolNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: tcpip.Address(net.ParseIP("192.168.0.1").To4()), + PrefixLen: 24, + }, + } + host2IPv4Addr = tcpip.ProtocolAddress{ + Protocol: ipv4.ProtocolNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: tcpip.Address(net.ParseIP("192.168.0.2").To4()), + PrefixLen: 24, + }, + } + ) + const icmpBurst = 5 + e := channel.New(1, defaultMTU, tcpip.LinkAddress("")) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, + Clock: faketime.NewManualClock(), + }) + s.SetICMPBurst(icmpBurst) + + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err) + } + if err := s.AddProtocolAddress(nicID, host1IPv4Addr, stack.AddressProperties{}); err != nil { + t.Fatalf("s.AddProtocolAddress(%d, %+v, {}): %s", nicID, host1IPv4Addr, err) + } + s.SetRouteTable([]tcpip.Route{ + { + Destination: host1IPv4Addr.AddressWithPrefix.Subnet(), + NIC: nicID, + }, + }) + tests := []struct { + name string + createPacket func() buffer.View + check func(*testing.T, *channel.Endpoint, int) + }{ + { + name: "echo", + createPacket: func() buffer.View { + totalLength := header.IPv4MinimumSize + header.ICMPv4MinimumSize + hdr := buffer.NewPrependable(totalLength) + icmpH := header.ICMPv4(hdr.Prepend(header.ICMPv4MinimumSize)) + icmpH.SetIdent(1) + icmpH.SetSequence(1) + icmpH.SetType(header.ICMPv4Echo) + icmpH.SetCode(header.ICMPv4UnusedCode) + icmpH.SetChecksum(0) + icmpH.SetChecksum(^header.Checksum(icmpH, 0)) + ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize)) + ip.Encode(&header.IPv4Fields{ + TotalLength: uint16(totalLength), + Protocol: uint8(header.ICMPv4ProtocolNumber), + TTL: 1, + SrcAddr: host2IPv4Addr.AddressWithPrefix.Address, + DstAddr: host1IPv4Addr.AddressWithPrefix.Address, + }) + ip.SetChecksum(^ip.CalculateChecksum()) + return hdr.View() + }, + check: func(t *testing.T, e *channel.Endpoint, round int) { + p, ok := e.Read() + if !ok { + t.Fatalf("expected echo response, no packet read in endpoint in round %d", round) + } + if got, want := p.Proto, header.IPv4ProtocolNumber; got != want { + t.Errorf("got p.Proto = %d, want = %d", got, want) + } + checker.IPv4(t, stack.PayloadSince(p.Pkt.NetworkHeader()), + checker.SrcAddr(host1IPv4Addr.AddressWithPrefix.Address), + checker.DstAddr(host2IPv4Addr.AddressWithPrefix.Address), + checker.ICMPv4( + checker.ICMPv4Type(header.ICMPv4EchoReply), + )) + }, + }, + { + name: "dst unreachable", + createPacket: func() buffer.View { + totalLength := header.IPv4MinimumSize + header.UDPMinimumSize + hdr := buffer.NewPrependable(totalLength) + udpH := header.UDP(hdr.Prepend(header.UDPMinimumSize)) + udpH.Encode(&header.UDPFields{ + SrcPort: 100, + DstPort: 101, + Length: header.UDPMinimumSize, + }) + ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize)) + ip.Encode(&header.IPv4Fields{ + TotalLength: uint16(totalLength), + Protocol: uint8(header.UDPProtocolNumber), + TTL: 1, + SrcAddr: host2IPv4Addr.AddressWithPrefix.Address, + DstAddr: host1IPv4Addr.AddressWithPrefix.Address, + }) + ip.SetChecksum(^ip.CalculateChecksum()) + return hdr.View() + }, + check: func(t *testing.T, e *channel.Endpoint, round int) { + p, ok := e.Read() + if round >= icmpBurst { + if ok { + t.Errorf("got packet %x in round %d, expected ICMP rate limit to stop it", p.Pkt.Data().Views(), round) + } + return + } + if !ok { + t.Fatalf("expected unreachable in round %d, no packet read in endpoint", round) + } + checker.IPv4(t, stack.PayloadSince(p.Pkt.NetworkHeader()), + checker.SrcAddr(host1IPv4Addr.AddressWithPrefix.Address), + checker.DstAddr(host2IPv4Addr.AddressWithPrefix.Address), + checker.ICMPv4( + checker.ICMPv4Type(header.ICMPv4DstUnreachable), + )) + }, + }, + } + for _, testCase := range tests { + t.Run(testCase.name, func(t *testing.T) { + for round := 0; round < icmpBurst+1; round++ { + e.InjectInbound(header.IPv4ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: testCase.createPacket().ToVectorisedView(), + })) + testCase.check(t, e, round) + } + }) + } +} diff --git a/pkg/tcpip/network/ipv6/BUILD b/pkg/tcpip/network/ipv6/BUILD index f99cbf8f3..f814926a3 100644 --- a/pkg/tcpip/network/ipv6/BUILD +++ b/pkg/tcpip/network/ipv6/BUILD @@ -51,6 +51,7 @@ go_test( "//pkg/tcpip/transport/udp", "//pkg/waiter", "@com_github_google_go_cmp//cmp:go_default_library", + "@org_golang_x_time//rate:go_default_library", ], ) diff --git a/pkg/tcpip/network/ipv6/icmp.go b/pkg/tcpip/network/ipv6/icmp.go index 94caaae6c..6c6107264 100644 --- a/pkg/tcpip/network/ipv6/icmp.go +++ b/pkg/tcpip/network/ipv6/icmp.go @@ -692,6 +692,11 @@ func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool, r } defer r.Release() + if !e.protocol.allowICMPReply(header.ICMPv6EchoReply) { + sent.rateLimited.Increment() + return + } + replyPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ ReserveHeaderBytes: int(r.MaxHeaderLength()) + header.ICMPv6EchoMinimumSize, Data: pkt.Data().ExtractVV(), @@ -1174,13 +1179,6 @@ func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) tcpip return &tcpip.ErrNotConnected{} } - sent := netEP.stats.icmp.packetsSent - - if !p.stack.AllowICMPMessage() { - sent.rateLimited.Increment() - return nil - } - if pkt.TransportProtocolNumber == header.ICMPv6ProtocolNumber { // TODO(gvisor.dev/issues/3810): Sort this out when ICMP headers are stored. // Unfortunately at this time ICMP Packets do not have a transport @@ -1198,6 +1196,33 @@ func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) tcpip } } + sent := netEP.stats.icmp.packetsSent + icmpType, icmpCode, counter, typeSpecific := func() (header.ICMPv6Type, header.ICMPv6Code, tcpip.MultiCounterStat, uint32) { + switch reason := reason.(type) { + case *icmpReasonParameterProblem: + return header.ICMPv6ParamProblem, reason.code, sent.paramProblem, reason.pointer + case *icmpReasonPortUnreachable: + return header.ICMPv6DstUnreachable, header.ICMPv6PortUnreachable, sent.dstUnreachable, 0 + case *icmpReasonNetUnreachable: + return header.ICMPv6DstUnreachable, header.ICMPv6NetworkUnreachable, sent.dstUnreachable, 0 + case *icmpReasonHostUnreachable: + return header.ICMPv6DstUnreachable, header.ICMPv6AddressUnreachable, sent.dstUnreachable, 0 + case *icmpReasonPacketTooBig: + return header.ICMPv6PacketTooBig, header.ICMPv6UnusedCode, sent.packetTooBig, 0 + case *icmpReasonHopLimitExceeded: + return header.ICMPv6TimeExceeded, header.ICMPv6HopLimitExceeded, sent.timeExceeded, 0 + case *icmpReasonReassemblyTimeout: + return header.ICMPv6TimeExceeded, header.ICMPv6ReassemblyTimeout, sent.timeExceeded, 0 + default: + panic(fmt.Sprintf("unsupported ICMP type %T", reason)) + } + }() + + if !p.allowICMPReply(icmpType) { + sent.rateLimited.Increment() + return nil + } + network, transport := pkt.NetworkHeader().View(), pkt.TransportHeader().View() // As per RFC 4443 section 2.4 @@ -1232,40 +1257,10 @@ func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) tcpip newPkt.TransportProtocolNumber = header.ICMPv6ProtocolNumber icmpHdr := header.ICMPv6(newPkt.TransportHeader().Push(header.ICMPv6DstUnreachableMinimumSize)) - var counter tcpip.MultiCounterStat - switch reason := reason.(type) { - case *icmpReasonParameterProblem: - icmpHdr.SetType(header.ICMPv6ParamProblem) - icmpHdr.SetCode(reason.code) - icmpHdr.SetTypeSpecific(reason.pointer) - counter = sent.paramProblem - case *icmpReasonPortUnreachable: - icmpHdr.SetType(header.ICMPv6DstUnreachable) - icmpHdr.SetCode(header.ICMPv6PortUnreachable) - counter = sent.dstUnreachable - case *icmpReasonNetUnreachable: - icmpHdr.SetType(header.ICMPv6DstUnreachable) - icmpHdr.SetCode(header.ICMPv6NetworkUnreachable) - counter = sent.dstUnreachable - case *icmpReasonHostUnreachable: - icmpHdr.SetType(header.ICMPv6DstUnreachable) - icmpHdr.SetCode(header.ICMPv6AddressUnreachable) - counter = sent.dstUnreachable - case *icmpReasonPacketTooBig: - icmpHdr.SetType(header.ICMPv6PacketTooBig) - icmpHdr.SetCode(header.ICMPv6UnusedCode) - counter = sent.packetTooBig - case *icmpReasonHopLimitExceeded: - icmpHdr.SetType(header.ICMPv6TimeExceeded) - icmpHdr.SetCode(header.ICMPv6HopLimitExceeded) - counter = sent.timeExceeded - case *icmpReasonReassemblyTimeout: - icmpHdr.SetType(header.ICMPv6TimeExceeded) - icmpHdr.SetCode(header.ICMPv6ReassemblyTimeout) - counter = sent.timeExceeded - default: - panic(fmt.Sprintf("unsupported ICMP type %T", reason)) - } + icmpHdr.SetType(icmpType) + icmpHdr.SetCode(icmpCode) + icmpHdr.SetTypeSpecific(typeSpecific) + dataRange := newPkt.Data().AsRange() icmpHdr.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{ Header: icmpHdr, diff --git a/pkg/tcpip/network/ipv6/icmp_test.go b/pkg/tcpip/network/ipv6/icmp_test.go index 3b4c235fa..03d9f425c 100644 --- a/pkg/tcpip/network/ipv6/icmp_test.go +++ b/pkg/tcpip/network/ipv6/icmp_test.go @@ -22,6 +22,7 @@ import ( "testing" "github.com/google/go-cmp/cmp" + "golang.org/x/time/rate" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/checker" @@ -1435,6 +1436,8 @@ func TestPacketQueing(t *testing.T) { TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, Clock: clock, }) + // Make sure ICMP rate limiting doesn't get in our way. + s.SetICMPLimit(rate.Inf) if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err) diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go index 3799d8245..e2d2cf907 100644 --- a/pkg/tcpip/network/ipv6/ipv6.go +++ b/pkg/tcpip/network/ipv6/ipv6.go @@ -748,7 +748,7 @@ func (e *endpoint) WritePacket(r *stack.Route, params stack.NetworkHeaderParams, // iptables filtering. All packets that reach here are locally // generated. outNicName := e.protocol.stack.FindNICNameFromID(e.nic.ID()) - if ok := e.protocol.stack.IPTables().Check(stack.Output, pkt, r, "" /* preroutingAddr */, "" /* inNicName */, outNicName); !ok { + if ok := e.protocol.stack.IPTables().CheckOutput(pkt, r, outNicName); !ok { // iptables is telling us to drop the packet. e.stats.ip.IPTablesOutputDropped.Increment() return nil @@ -788,7 +788,7 @@ func (e *endpoint) writePacket(r *stack.Route, pkt *stack.PacketBuffer, protocol // Postrouting NAT can only change the source address, and does not alter the // route or outgoing interface of the packet. outNicName := e.protocol.stack.FindNICNameFromID(e.nic.ID()) - if ok := e.protocol.stack.IPTables().Check(stack.Postrouting, pkt, r, "" /* preroutingAddr */, "" /* inNicName */, outNicName); !ok { + if ok := e.protocol.stack.IPTables().CheckPostrouting(pkt, r, outNicName); !ok { // iptables is telling us to drop the packet. e.stats.ip.IPTablesPostroutingDropped.Increment() return nil @@ -871,7 +871,7 @@ func (e *endpoint) WritePackets(r *stack.Route, pkts stack.PacketBufferList, par // iptables filtering. All packets that reach here are locally // generated. outNicName := e.protocol.stack.FindNICNameFromID(e.nic.ID()) - outputDropped, natPkts := e.protocol.stack.IPTables().CheckPackets(stack.Output, pkts, r, "" /* inNicName */, outNicName) + outputDropped, natPkts := e.protocol.stack.IPTables().CheckOutputPackets(pkts, r, outNicName) stats.IPTablesOutputDropped.IncrementBy(uint64(len(outputDropped))) for pkt := range outputDropped { pkts.Remove(pkt) @@ -897,7 +897,7 @@ func (e *endpoint) WritePackets(r *stack.Route, pkts stack.PacketBufferList, par // We ignore the list of NAT-ed packets here because Postrouting NAT can only // change the source address, and does not alter the route or outgoing // interface of the packet. - postroutingDropped, _ := e.protocol.stack.IPTables().CheckPackets(stack.Postrouting, pkts, r, "" /* inNicName */, outNicName) + postroutingDropped, _ := e.protocol.stack.IPTables().CheckPostroutingPackets(pkts, r, outNicName) stats.IPTablesPostroutingDropped.IncrementBy(uint64(len(postroutingDropped))) for pkt := range postroutingDropped { pkts.Remove(pkt) @@ -984,7 +984,7 @@ func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) ip.ForwardingError { if ep := e.protocol.findEndpointWithAddress(dstAddr); ep != nil { inNicName := stk.FindNICNameFromID(e.nic.ID()) outNicName := stk.FindNICNameFromID(ep.nic.ID()) - if ok := stk.IPTables().Check(stack.Forward, pkt, nil, "" /* preroutingAddr */, inNicName, outNicName); !ok { + if ok := stk.IPTables().CheckForward(pkt, inNicName, outNicName); !ok { // iptables is telling us to drop the packet. e.stats.ip.IPTablesForwardDropped.Increment() return nil @@ -1015,7 +1015,7 @@ func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) ip.ForwardingError { inNicName := stk.FindNICNameFromID(e.nic.ID()) outNicName := stk.FindNICNameFromID(r.NICID()) - if ok := stk.IPTables().Check(stack.Forward, pkt, nil, "" /* preroutingAddr */, inNicName, outNicName); !ok { + if ok := stk.IPTables().CheckForward(pkt, inNicName, outNicName); !ok { // iptables is telling us to drop the packet. e.stats.ip.IPTablesForwardDropped.Increment() return nil @@ -1100,7 +1100,7 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) { // Loopback traffic skips the prerouting chain. inNicName := e.protocol.stack.FindNICNameFromID(e.nic.ID()) - if ok := e.protocol.stack.IPTables().Check(stack.Prerouting, pkt, nil, e.MainAddress().Address, inNicName, "" /* outNicName */); !ok { + if ok := e.protocol.stack.IPTables().CheckPrerouting(pkt, e, inNicName); !ok { // iptables is telling us to drop the packet. stats.IPTablesPreroutingDropped.Increment() return @@ -1183,7 +1183,7 @@ func (e *endpoint) handleValidatedPacket(h header.IPv6, pkt *stack.PacketBuffer, // iptables filtering. All packets that reach here are intended for // this machine and need not be forwarded. - if ok := e.protocol.stack.IPTables().Check(stack.Input, pkt, nil, "" /* preroutingAddr */, inNICName, "" /* outNicName */); !ok { + if ok := e.protocol.stack.IPTables().CheckInput(pkt, inNICName); !ok { // iptables is telling us to drop the packet. stats.IPTablesInputDropped.Increment() return @@ -1990,6 +1990,9 @@ type protocol struct { // eps is keyed by NICID to allow protocol methods to retrieve an endpoint // when handling a packet, by looking at which NIC handled the packet. eps map[tcpip.NICID]*endpoint + + // ICMP types for which the stack's global rate limiting must apply. + icmpRateLimitedTypes map[header.ICMPv6Type]struct{} } ids []uint32 @@ -2001,7 +2004,8 @@ type protocol struct { // Must be accessed using atomic operations. defaultTTL uint32 - fragmentation *fragmentation.Fragmentation + fragmentation *fragmentation.Fragmentation + icmpRateLimiter *stack.ICMPRateLimiter } // Number returns the ipv6 protocol number. @@ -2177,6 +2181,18 @@ func (*protocol) Parse(pkt *stack.PacketBuffer) (proto tcpip.TransportProtocolNu return proto, !fragMore && fragOffset == 0, true } +// allowICMPReply reports whether an ICMP reply with provided type may +// be sent following the rate mask options and global ICMP rate limiter. +func (p *protocol) allowICMPReply(icmpType header.ICMPv6Type) bool { + p.mu.RLock() + defer p.mu.RUnlock() + + if _, ok := p.mu.icmpRateLimitedTypes[icmpType]; ok { + return p.stack.AllowICMPMessage() + } + return true +} + // calculateNetworkMTU calculates the network-layer payload MTU based on the // link-layer payload MTU and the length of every IPv6 header. // Note that this is different than the Payload Length field of the IPv6 header, @@ -2273,6 +2289,21 @@ func NewProtocolWithOptions(opts Options) stack.NetworkProtocolFactory { p.fragmentation = fragmentation.NewFragmentation(header.IPv6FragmentExtHdrFragmentOffsetBytesPerUnit, fragmentation.HighFragThreshold, fragmentation.LowFragThreshold, ReassembleTimeout, s.Clock(), p) p.mu.eps = make(map[tcpip.NICID]*endpoint) p.SetDefaultTTL(DefaultTTL) + // Set default ICMP rate limiting to Linux defaults. + // + // Default: 0-1,3-127 (rate limit ICMPv6 errors except Packet Too Big) + // See https://www.kernel.org/doc/Documentation/networking/ip-sysctl.txt. + defaultIcmpTypes := make(map[header.ICMPv6Type]struct{}) + for i := header.ICMPv6Type(0); i < header.ICMPv6EchoRequest; i++ { + switch i { + case header.ICMPv6PacketTooBig: + // Do not rate limit packet too big by default. + default: + defaultIcmpTypes[i] = struct{}{} + } + } + p.mu.icmpRateLimitedTypes = defaultIcmpTypes + return p } } diff --git a/pkg/tcpip/network/ipv6/ipv6_test.go b/pkg/tcpip/network/ipv6/ipv6_test.go index 74f19900c..e5286081e 100644 --- a/pkg/tcpip/network/ipv6/ipv6_test.go +++ b/pkg/tcpip/network/ipv6/ipv6_test.go @@ -3522,3 +3522,149 @@ func TestMultiCounterStatsInitialization(t *testing.T) { t.Error(err) } } + +func TestIcmpRateLimit(t *testing.T) { + var ( + host1IPv6Addr = tcpip.ProtocolAddress{ + Protocol: ProtocolNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: tcpip.Address(net.ParseIP("10::1").To16()), + PrefixLen: 64, + }, + } + host2IPv6Addr = tcpip.ProtocolAddress{ + Protocol: ProtocolNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: tcpip.Address(net.ParseIP("10::2").To16()), + PrefixLen: 64, + }, + } + ) + const icmpBurst = 5 + e := channel.New(1, defaultMTU, tcpip.LinkAddress("")) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, + Clock: faketime.NewManualClock(), + }) + s.SetICMPBurst(icmpBurst) + + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err) + } + if err := s.AddProtocolAddress(nicID, host1IPv6Addr, stack.AddressProperties{}); err != nil { + t.Fatalf("s.AddProtocolAddress(%d, %+v, {}): %s", nicID, host1IPv6Addr, err) + } + s.SetRouteTable([]tcpip.Route{ + { + Destination: host1IPv6Addr.AddressWithPrefix.Subnet(), + NIC: nicID, + }, + }) + tests := []struct { + name string + createPacket func() buffer.View + check func(*testing.T, *channel.Endpoint, int) + }{ + { + name: "echo", + createPacket: func() buffer.View { + totalLength := header.IPv6MinimumSize + header.ICMPv6MinimumSize + hdr := buffer.NewPrependable(totalLength) + icmpH := header.ICMPv6(hdr.Prepend(header.ICMPv6MinimumSize)) + icmpH.SetIdent(1) + icmpH.SetSequence(1) + icmpH.SetType(header.ICMPv6EchoRequest) + icmpH.SetCode(header.ICMPv6UnusedCode) + icmpH.SetChecksum(0) + icmpH.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{ + Header: icmpH, + Src: host2IPv6Addr.AddressWithPrefix.Address, + Dst: host1IPv6Addr.AddressWithPrefix.Address, + })) + payloadLength := hdr.UsedLength() + ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) + ip.Encode(&header.IPv6Fields{ + PayloadLength: uint16(payloadLength), + TransportProtocol: header.ICMPv6ProtocolNumber, + HopLimit: 1, + SrcAddr: host2IPv6Addr.AddressWithPrefix.Address, + DstAddr: host1IPv6Addr.AddressWithPrefix.Address, + }) + return hdr.View() + }, + check: func(t *testing.T, e *channel.Endpoint, round int) { + p, ok := e.Read() + if !ok { + t.Fatalf("expected echo response, no packet read in endpoint in round %d", round) + } + if got, want := p.Proto, header.IPv6ProtocolNumber; got != want { + t.Errorf("got p.Proto = %d, want = %d", got, want) + } + checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()), + checker.SrcAddr(host1IPv6Addr.AddressWithPrefix.Address), + checker.DstAddr(host2IPv6Addr.AddressWithPrefix.Address), + checker.ICMPv6( + checker.ICMPv6Type(header.ICMPv6EchoReply), + )) + }, + }, + { + name: "dst unreachable", + createPacket: func() buffer.View { + totalLength := header.IPv6MinimumSize + header.UDPMinimumSize + hdr := buffer.NewPrependable(totalLength) + udpH := header.UDP(hdr.Prepend(header.UDPMinimumSize)) + udpH.Encode(&header.UDPFields{ + SrcPort: 100, + DstPort: 101, + Length: header.UDPMinimumSize, + }) + + // Calculate the UDP checksum and set it. + sum := header.PseudoHeaderChecksum(udp.ProtocolNumber, host2IPv6Addr.AddressWithPrefix.Address, host1IPv6Addr.AddressWithPrefix.Address, header.UDPMinimumSize) + sum = header.Checksum(nil, sum) + udpH.SetChecksum(^udpH.CalculateChecksum(sum)) + + payloadLength := hdr.UsedLength() + ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) + ip.Encode(&header.IPv6Fields{ + PayloadLength: uint16(payloadLength), + TransportProtocol: header.UDPProtocolNumber, + HopLimit: 1, + SrcAddr: host2IPv6Addr.AddressWithPrefix.Address, + DstAddr: host1IPv6Addr.AddressWithPrefix.Address, + }) + return hdr.View() + }, + check: func(t *testing.T, e *channel.Endpoint, round int) { + p, ok := e.Read() + if round >= icmpBurst { + if ok { + t.Errorf("got packet %x in round %d, expected ICMP rate limit to stop it", p.Pkt.Data().Views(), round) + } + return + } + if !ok { + t.Fatalf("expected unreachable in round %d, no packet read in endpoint", round) + } + checker.IPv6(t, stack.PayloadSince(p.Pkt.NetworkHeader()), + checker.SrcAddr(host1IPv6Addr.AddressWithPrefix.Address), + checker.DstAddr(host2IPv6Addr.AddressWithPrefix.Address), + checker.ICMPv6( + checker.ICMPv6Type(header.ICMPv6DstUnreachable), + )) + }, + }, + } + for _, testCase := range tests { + t.Run(testCase.name, func(t *testing.T) { + for round := 0; round < icmpBurst+1; round++ { + e.InjectInbound(header.IPv6ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: testCase.createPacket().ToVectorisedView(), + })) + testCase.check(t, e, round) + } + }) + } +} diff --git a/pkg/tcpip/stack/conntrack.go b/pkg/tcpip/stack/conntrack.go index 068dab7ce..4fb7e9adb 100644 --- a/pkg/tcpip/stack/conntrack.go +++ b/pkg/tcpip/stack/conntrack.go @@ -160,7 +160,13 @@ func (cn *conn) timedOut(now time.Time) bool { // update the connection tracking state. // // Precondition: cn.mu must be held. -func (cn *conn) updateLocked(tcpHeader header.TCP, hook Hook) { +func (cn *conn) updateLocked(pkt *PacketBuffer, hook Hook) { + if pkt.TransportProtocolNumber != header.TCPProtocolNumber { + return + } + + tcpHeader := header.TCP(pkt.TransportHeader().View()) + // Update the state of tcb. tcb assumes it's always initialized on the // client. However, we only need to know whether the connection is // established or not, so the client/server distinction isn't important. @@ -209,27 +215,38 @@ type bucket struct { tuples tupleList } +func getTransportHeader(pkt *PacketBuffer) (header.ChecksummableTransport, bool) { + switch pkt.TransportProtocolNumber { + case header.TCPProtocolNumber: + if tcpHeader := header.TCP(pkt.TransportHeader().View()); len(tcpHeader) >= header.TCPMinimumSize { + return tcpHeader, true + } + case header.UDPProtocolNumber: + if udpHeader := header.UDP(pkt.TransportHeader().View()); len(udpHeader) >= header.UDPMinimumSize { + return udpHeader, true + } + } + + return nil, false +} + // packetToTupleID converts packet to a tuple ID. It fails when pkt lacks a valid // TCP header. // // Preconditions: pkt.NetworkHeader() is valid. func packetToTupleID(pkt *PacketBuffer) (tupleID, tcpip.Error) { netHeader := pkt.Network() - if netHeader.TransportProtocol() != header.TCPProtocolNumber { - return tupleID{}, &tcpip.ErrUnknownProtocol{} - } - - tcpHeader := header.TCP(pkt.TransportHeader().View()) - if len(tcpHeader) < header.TCPMinimumSize { + transportHeader, ok := getTransportHeader(pkt) + if !ok { return tupleID{}, &tcpip.ErrUnknownProtocol{} } return tupleID{ srcAddr: netHeader.SourceAddress(), - srcPort: tcpHeader.SourcePort(), + srcPort: transportHeader.SourcePort(), dstAddr: netHeader.DestinationAddress(), - dstPort: tcpHeader.DestinationPort(), - transProto: netHeader.TransportProtocol(), + dstPort: transportHeader.DestinationPort(), + transProto: pkt.TransportProtocolNumber, netProto: pkt.NetworkProtocolNumber, }, nil } @@ -381,8 +398,8 @@ func (ct *ConnTrack) handlePacket(pkt *PacketBuffer, hook Hook, r *Route) bool { return false } - // TODO(gvisor.dev/issue/6168): Support UDP. - if pkt.Network().TransportProtocol() != header.TCPProtocolNumber { + transportHeader, ok := getTransportHeader(pkt) + if !ok { return false } @@ -396,10 +413,6 @@ func (ct *ConnTrack) handlePacket(pkt *PacketBuffer, hook Hook, r *Route) bool { } netHeader := pkt.Network() - tcpHeader := header.TCP(pkt.TransportHeader().View()) - if len(tcpHeader) < header.TCPMinimumSize { - return false - } // TODO(gvisor.dev/issue/5748): TCP checksums on inbound packets should be // validated if checksum offloading is off. It may require IP defrag if the @@ -412,36 +425,31 @@ func (ct *ConnTrack) handlePacket(pkt *PacketBuffer, hook Hook, r *Route) bool { switch hook { case Prerouting, Output: - if conn.manip == manipDestination { - switch dir { - case dirOriginal: - newPort = conn.reply.srcPort - newAddr = conn.reply.srcAddr - case dirReply: - newPort = conn.original.dstPort - newAddr = conn.original.dstAddr - - updateSRCFields = true - } + if conn.manip == manipDestination && dir == dirOriginal { + newPort = conn.reply.srcPort + newAddr = conn.reply.srcAddr + pkt.NatDone = true + } else if conn.manip == manipSource && dir == dirReply { + newPort = conn.original.srcPort + newAddr = conn.original.srcAddr pkt.NatDone = true } case Input, Postrouting: - if conn.manip == manipSource { - switch dir { - case dirOriginal: - newPort = conn.reply.dstPort - newAddr = conn.reply.dstAddr - - updateSRCFields = true - case dirReply: - newPort = conn.original.srcPort - newAddr = conn.original.srcAddr - } + if conn.manip == manipSource && dir == dirOriginal { + newPort = conn.reply.dstPort + newAddr = conn.reply.dstAddr + updateSRCFields = true + pkt.NatDone = true + } else if conn.manip == manipDestination && dir == dirReply { + newPort = conn.original.dstPort + newAddr = conn.original.dstAddr + updateSRCFields = true pkt.NatDone = true } default: panic(fmt.Sprintf("unrecognized hook = %s", hook)) } + if !pkt.NatDone { return false } @@ -449,10 +457,15 @@ func (ct *ConnTrack) handlePacket(pkt *PacketBuffer, hook Hook, r *Route) bool { fullChecksum := false updatePseudoHeader := false switch hook { - case Prerouting, Input: + case Prerouting: + // Packet came from outside the stack so it must have a checksum set + // already. + fullChecksum = true + updatePseudoHeader = true + case Input: case Output, Postrouting: // Calculate the TCP checksum and set it. - if pkt.GSOOptions.Type != GSONone && pkt.GSOOptions.NeedsCsum { + if pkt.TransportProtocolNumber == header.TCPProtocolNumber && pkt.GSOOptions.Type != GSONone && pkt.GSOOptions.NeedsCsum { updatePseudoHeader = true } else if r.RequiresTXTransportChecksum() { fullChecksum = true @@ -464,7 +477,7 @@ func (ct *ConnTrack) handlePacket(pkt *PacketBuffer, hook Hook, r *Route) bool { rewritePacket( netHeader, - tcpHeader, + transportHeader, updateSRCFields, fullChecksum, updatePseudoHeader, @@ -479,7 +492,7 @@ func (ct *ConnTrack) handlePacket(pkt *PacketBuffer, hook Hook, r *Route) bool { // Mark the connection as having been used recently so it isn't reaped. conn.lastUsed = time.Now() // Update connection state. - conn.updateLocked(header.TCP(pkt.TransportHeader().View()), hook) + conn.updateLocked(pkt, hook) return false } @@ -497,8 +510,11 @@ func (ct *ConnTrack) maybeInsertNoop(pkt *PacketBuffer, hook Hook) { return } - // We only track TCP connections. - if pkt.Network().TransportProtocol() != header.TCPProtocolNumber { + switch pkt.TransportProtocolNumber { + case header.TCPProtocolNumber, header.UDPProtocolNumber: + default: + // TODO(https://gvisor.dev/issue/5915): Track ICMP and other trackable + // connections. return } @@ -510,7 +526,7 @@ func (ct *ConnTrack) maybeInsertNoop(pkt *PacketBuffer, hook Hook) { return } conn := newConn(tid, tid.reply(), manipNone, hook) - conn.updateLocked(header.TCP(pkt.TransportHeader().View()), hook) + conn.updateLocked(pkt, hook) ct.insertConn(conn) } @@ -632,7 +648,7 @@ func (ct *ConnTrack) reapTupleLocked(tuple *tuple, bucket int, now time.Time) bo return true } -func (ct *ConnTrack) originalDst(epID TransportEndpointID, netProto tcpip.NetworkProtocolNumber) (tcpip.Address, uint16, tcpip.Error) { +func (ct *ConnTrack) originalDst(epID TransportEndpointID, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber) (tcpip.Address, uint16, tcpip.Error) { // Lookup the connection. The reply's original destination // describes the original address. tid := tupleID{ @@ -640,7 +656,7 @@ func (ct *ConnTrack) originalDst(epID TransportEndpointID, netProto tcpip.Networ srcPort: epID.LocalPort, dstAddr: epID.RemoteAddress, dstPort: epID.RemotePort, - transProto: header.TCPProtocolNumber, + transProto: transProto, netProto: netProto, } conn, _ := ct.connForTID(tid) diff --git a/pkg/tcpip/stack/iptables.go b/pkg/tcpip/stack/iptables.go index f152c0d83..74c9075b4 100644 --- a/pkg/tcpip/stack/iptables.go +++ b/pkg/tcpip/stack/iptables.go @@ -264,12 +264,62 @@ const ( chainReturn ) -// Check runs pkt through the rules for hook. It returns true when the packet +// CheckPrerouting performs the prerouting hook on the packet. +// +// Returns true iff the packet may continue traversing the stack; the packet +// must be dropped if false is returned. +// +// Precondition: The packet's network and transport header must be set. +func (it *IPTables) CheckPrerouting(pkt *PacketBuffer, addressEP AddressableEndpoint, inNicName string) bool { + return it.check(Prerouting, pkt, nil /* route */, addressEP, inNicName, "" /* outNicName */) +} + +// CheckInput performs the input hook on the packet. +// +// Returns true iff the packet may continue traversing the stack; the packet +// must be dropped if false is returned. +// +// Precondition: The packet's network and transport header must be set. +func (it *IPTables) CheckInput(pkt *PacketBuffer, inNicName string) bool { + return it.check(Input, pkt, nil /* route */, nil /* addressEP */, inNicName, "" /* outNicName */) +} + +// CheckForward performs the forward hook on the packet. +// +// Returns true iff the packet may continue traversing the stack; the packet +// must be dropped if false is returned. +// +// Precondition: The packet's network and transport header must be set. +func (it *IPTables) CheckForward(pkt *PacketBuffer, inNicName, outNicName string) bool { + return it.check(Forward, pkt, nil /* route */, nil /* addressEP */, inNicName, outNicName) +} + +// CheckOutput performs the output hook on the packet. +// +// Returns true iff the packet may continue traversing the stack; the packet +// must be dropped if false is returned. +// +// Precondition: The packet's network and transport header must be set. +func (it *IPTables) CheckOutput(pkt *PacketBuffer, r *Route, outNicName string) bool { + return it.check(Output, pkt, r, nil /* addressEP */, "" /* inNicName */, outNicName) +} + +// CheckPostrouting performs the postrouting hook on the packet. +// +// Returns true iff the packet may continue traversing the stack; the packet +// must be dropped if false is returned. +// +// Precondition: The packet's network and transport header must be set. +func (it *IPTables) CheckPostrouting(pkt *PacketBuffer, r *Route, outNicName string) bool { + return it.check(Postrouting, pkt, r, nil /* addressEP */, "" /* inNicName */, outNicName) +} + +// check runs pkt through the rules for hook. It returns true when the packet // should continue traversing the network stack and false when it should be // dropped. // -// Precondition: pkt.NetworkHeader is set. -func (it *IPTables) Check(hook Hook, pkt *PacketBuffer, r *Route, preroutingAddr tcpip.Address, inNicName, outNicName string) bool { +// Precondition: The packet's network and transport header must be set. +func (it *IPTables) check(hook Hook, pkt *PacketBuffer, r *Route, addressEP AddressableEndpoint, inNicName, outNicName string) bool { if pkt.NetworkProtocolNumber != header.IPv4ProtocolNumber && pkt.NetworkProtocolNumber != header.IPv6ProtocolNumber { return true } @@ -300,7 +350,7 @@ func (it *IPTables) Check(hook Hook, pkt *PacketBuffer, r *Route, preroutingAddr table = it.v4Tables[tableID] } ruleIdx := table.BuiltinChains[hook] - switch verdict := it.checkChain(hook, pkt, table, ruleIdx, r, preroutingAddr, inNicName, outNicName); verdict { + switch verdict := it.checkChain(hook, pkt, table, ruleIdx, r, addressEP, inNicName, outNicName); verdict { // If the table returns Accept, move on to the next table. case chainAccept: continue @@ -311,7 +361,7 @@ func (it *IPTables) Check(hook Hook, pkt *PacketBuffer, r *Route, preroutingAddr // Any Return from a built-in chain means we have to // call the underflow. underflow := table.Rules[table.Underflows[hook]] - switch v, _ := underflow.Target.Action(pkt, &it.connections, hook, r, preroutingAddr); v { + switch v, _ := underflow.Target.Action(pkt, &it.connections, hook, r, addressEP); v { case RuleAccept: continue case RuleDrop: @@ -375,19 +425,35 @@ func (it *IPTables) startReaper(interval time.Duration) { }() } -// CheckPackets runs pkts through the rules for hook and returns a map of packets that -// should not go forward. +// CheckOutputPackets performs the output hook on the packets. // -// Preconditions: -// * pkt is a IPv4 packet of at least length header.IPv4MinimumSize. -// * pkt.NetworkHeader is not nil. +// Returns a map of packets that must be dropped. +// +// Precondition: The packets' network and transport header must be set. +func (it *IPTables) CheckOutputPackets(pkts PacketBufferList, r *Route, outNicName string) (drop map[*PacketBuffer]struct{}, natPkts map[*PacketBuffer]struct{}) { + return it.checkPackets(Output, pkts, r, outNicName) +} + +// CheckPostroutingPackets performs the postrouting hook on the packets. +// +// Returns a map of packets that must be dropped. +// +// Precondition: The packets' network and transport header must be set. +func (it *IPTables) CheckPostroutingPackets(pkts PacketBufferList, r *Route, outNicName string) (drop map[*PacketBuffer]struct{}, natPkts map[*PacketBuffer]struct{}) { + return it.checkPackets(Postrouting, pkts, r, outNicName) +} + +// checkPackets runs pkts through the rules for hook and returns a map of +// packets that should not go forward. // // NOTE: unlike the Check API the returned map contains packets that should be // dropped. -func (it *IPTables) CheckPackets(hook Hook, pkts PacketBufferList, r *Route, inNicName, outNicName string) (drop map[*PacketBuffer]struct{}, natPkts map[*PacketBuffer]struct{}) { +// +// Precondition: The packets' network and transport header must be set. +func (it *IPTables) checkPackets(hook Hook, pkts PacketBufferList, r *Route, outNicName string) (drop map[*PacketBuffer]struct{}, natPkts map[*PacketBuffer]struct{}) { for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() { if !pkt.NatDone { - if ok := it.Check(hook, pkt, r, "", inNicName, outNicName); !ok { + if ok := it.check(hook, pkt, r, nil /* addressEP */, "" /* inNicName */, outNicName); !ok { if drop == nil { drop = make(map[*PacketBuffer]struct{}) } @@ -407,11 +473,11 @@ func (it *IPTables) CheckPackets(hook Hook, pkts PacketBufferList, r *Route, inN // Preconditions: // * pkt is a IPv4 packet of at least length header.IPv4MinimumSize. // * pkt.NetworkHeader is not nil. -func (it *IPTables) checkChain(hook Hook, pkt *PacketBuffer, table Table, ruleIdx int, r *Route, preroutingAddr tcpip.Address, inNicName, outNicName string) chainVerdict { +func (it *IPTables) checkChain(hook Hook, pkt *PacketBuffer, table Table, ruleIdx int, r *Route, addressEP AddressableEndpoint, inNicName, outNicName string) chainVerdict { // Start from ruleIdx and walk the list of rules until a rule gives us // a verdict. for ruleIdx < len(table.Rules) { - switch verdict, jumpTo := it.checkRule(hook, pkt, table, ruleIdx, r, preroutingAddr, inNicName, outNicName); verdict { + switch verdict, jumpTo := it.checkRule(hook, pkt, table, ruleIdx, r, addressEP, inNicName, outNicName); verdict { case RuleAccept: return chainAccept @@ -428,7 +494,7 @@ func (it *IPTables) checkChain(hook Hook, pkt *PacketBuffer, table Table, ruleId ruleIdx++ continue } - switch verdict := it.checkChain(hook, pkt, table, jumpTo, r, preroutingAddr, inNicName, outNicName); verdict { + switch verdict := it.checkChain(hook, pkt, table, jumpTo, r, addressEP, inNicName, outNicName); verdict { case chainAccept: return chainAccept case chainDrop: @@ -454,7 +520,7 @@ func (it *IPTables) checkChain(hook Hook, pkt *PacketBuffer, table Table, ruleId // Preconditions: // * pkt is a IPv4 packet of at least length header.IPv4MinimumSize. // * pkt.NetworkHeader is not nil. -func (it *IPTables) checkRule(hook Hook, pkt *PacketBuffer, table Table, ruleIdx int, r *Route, preroutingAddr tcpip.Address, inNicName, outNicName string) (RuleVerdict, int) { +func (it *IPTables) checkRule(hook Hook, pkt *PacketBuffer, table Table, ruleIdx int, r *Route, addressEP AddressableEndpoint, inNicName, outNicName string) (RuleVerdict, int) { rule := table.Rules[ruleIdx] // Check whether the packet matches the IP header filter. @@ -477,16 +543,16 @@ func (it *IPTables) checkRule(hook Hook, pkt *PacketBuffer, table Table, ruleIdx } // All the matchers matched, so run the target. - return rule.Target.Action(pkt, &it.connections, hook, r, preroutingAddr) + return rule.Target.Action(pkt, &it.connections, hook, r, addressEP) } // OriginalDst returns the original destination of redirected connections. It // returns an error if the connection doesn't exist or isn't redirected. -func (it *IPTables) OriginalDst(epID TransportEndpointID, netProto tcpip.NetworkProtocolNumber) (tcpip.Address, uint16, tcpip.Error) { +func (it *IPTables) OriginalDst(epID TransportEndpointID, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber) (tcpip.Address, uint16, tcpip.Error) { it.mu.RLock() defer it.mu.RUnlock() if !it.modified { return "", 0, &tcpip.ErrNotConnected{} } - return it.connections.originalDst(epID, netProto) + return it.connections.originalDst(epID, netProto, transProto) } diff --git a/pkg/tcpip/stack/iptables_targets.go b/pkg/tcpip/stack/iptables_targets.go index 96cc899bb..e8806ebdb 100644 --- a/pkg/tcpip/stack/iptables_targets.go +++ b/pkg/tcpip/stack/iptables_targets.go @@ -29,7 +29,7 @@ type AcceptTarget struct { } // Action implements Target.Action. -func (*AcceptTarget) Action(*PacketBuffer, *ConnTrack, Hook, *Route, tcpip.Address) (RuleVerdict, int) { +func (*AcceptTarget) Action(*PacketBuffer, *ConnTrack, Hook, *Route, AddressableEndpoint) (RuleVerdict, int) { return RuleAccept, 0 } @@ -40,7 +40,7 @@ type DropTarget struct { } // Action implements Target.Action. -func (*DropTarget) Action(*PacketBuffer, *ConnTrack, Hook, *Route, tcpip.Address) (RuleVerdict, int) { +func (*DropTarget) Action(*PacketBuffer, *ConnTrack, Hook, *Route, AddressableEndpoint) (RuleVerdict, int) { return RuleDrop, 0 } @@ -52,7 +52,7 @@ type ErrorTarget struct { } // Action implements Target.Action. -func (*ErrorTarget) Action(*PacketBuffer, *ConnTrack, Hook, *Route, tcpip.Address) (RuleVerdict, int) { +func (*ErrorTarget) Action(*PacketBuffer, *ConnTrack, Hook, *Route, AddressableEndpoint) (RuleVerdict, int) { log.Debugf("ErrorTarget triggered.") return RuleDrop, 0 } @@ -67,7 +67,7 @@ type UserChainTarget struct { } // Action implements Target.Action. -func (*UserChainTarget) Action(*PacketBuffer, *ConnTrack, Hook, *Route, tcpip.Address) (RuleVerdict, int) { +func (*UserChainTarget) Action(*PacketBuffer, *ConnTrack, Hook, *Route, AddressableEndpoint) (RuleVerdict, int) { panic("UserChainTarget should never be called.") } @@ -79,7 +79,7 @@ type ReturnTarget struct { } // Action implements Target.Action. -func (*ReturnTarget) Action(*PacketBuffer, *ConnTrack, Hook, *Route, tcpip.Address) (RuleVerdict, int) { +func (*ReturnTarget) Action(*PacketBuffer, *ConnTrack, Hook, *Route, AddressableEndpoint) (RuleVerdict, int) { return RuleReturn, 0 } @@ -97,7 +97,7 @@ type RedirectTarget struct { } // Action implements Target.Action. -func (rt *RedirectTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, r *Route, address tcpip.Address) (RuleVerdict, int) { +func (rt *RedirectTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, r *Route, addressEP AddressableEndpoint) (RuleVerdict, int) { // Sanity check. if rt.NetworkProtocol != pkt.NetworkProtocolNumber { panic(fmt.Sprintf( @@ -117,6 +117,7 @@ func (rt *RedirectTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, r // Change the address to loopback (127.0.0.1 or ::1) in Output and to // the primary address of the incoming interface in Prerouting. + var address tcpip.Address switch hook { case Output: if pkt.NetworkProtocolNumber == header.IPv4ProtocolNumber { @@ -125,7 +126,8 @@ func (rt *RedirectTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, r address = header.IPv6Loopback } case Prerouting: - // No-op, as address is already set correctly. + // addressEP is expected to be set for the prerouting hook. + address = addressEP.MainAddress().Address default: panic("redirect target is supported only on output and prerouting hooks") } @@ -180,7 +182,7 @@ type SNATTarget struct { } // Action implements Target.Action. -func (st *SNATTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, r *Route, address tcpip.Address) (RuleVerdict, int) { +func (st *SNATTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, r *Route, _ AddressableEndpoint) (RuleVerdict, int) { // Sanity check. if st.NetworkProtocol != pkt.NetworkProtocolNumber { panic(fmt.Sprintf( @@ -206,34 +208,28 @@ func (st *SNATTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, r *Rou panic(fmt.Sprintf("%s unrecognized", hook)) } - switch protocol := pkt.TransportProtocolNumber; protocol { - case header.UDPProtocolNumber: - // Only calculate the checksum if offloading isn't supported. - requiresChecksum := r.RequiresTXTransportChecksum() - rewritePacket( - pkt.Network(), - header.UDP(pkt.TransportHeader().View()), - true, /* updateSRCFields */ - requiresChecksum, - requiresChecksum, - st.Port, - st.Addr, - ) - - pkt.NatDone = true - case header.TCPProtocolNumber: - if ct == nil { - return RuleAccept, 0 + port := st.Port + + if port == 0 { + switch protocol := pkt.TransportProtocolNumber; protocol { + case header.UDPProtocolNumber: + if port == 0 { + port = header.UDP(pkt.TransportHeader().View()).SourcePort() + } + case header.TCPProtocolNumber: + if port == 0 { + port = header.TCP(pkt.TransportHeader().View()).SourcePort() + } } + } - // Set up conection for matching NAT rule. Only the first - // packet of the connection comes here. Other packets will be - // manipulated in connection tracking. - if conn := ct.insertSNATConn(pkt, hook, st.Port, st.Addr); conn != nil { - ct.handlePacket(pkt, hook, r) - } - default: - return RuleDrop, 0 + // Set up conection for matching NAT rule. Only the first packet of the + // connection comes here. Other packets will be manipulated in connection + // tracking. + // + // Does nothing if the protocol does not support connection tracking. + if conn := ct.insertSNATConn(pkt, hook, port, st.Addr); conn != nil { + ct.handlePacket(pkt, hook, r) } return RuleAccept, 0 diff --git a/pkg/tcpip/stack/iptables_types.go b/pkg/tcpip/stack/iptables_types.go index 66e5f22ac..976194124 100644 --- a/pkg/tcpip/stack/iptables_types.go +++ b/pkg/tcpip/stack/iptables_types.go @@ -352,5 +352,5 @@ type Target interface { // Action takes an action on the packet and returns a verdict on how // traversal should (or should not) continue. If the return value is // Jump, it also returns the index of the rule to jump to. - Action(*PacketBuffer, *ConnTrack, Hook, *Route, tcpip.Address) (RuleVerdict, int) + Action(*PacketBuffer, *ConnTrack, Hook, *Route, AddressableEndpoint) (RuleVerdict, int) } diff --git a/pkg/tcpip/stack/packet_buffer.go b/pkg/tcpip/stack/packet_buffer.go index b9280c2de..bf248ef20 100644 --- a/pkg/tcpip/stack/packet_buffer.go +++ b/pkg/tcpip/stack/packet_buffer.go @@ -335,9 +335,7 @@ func (pk *PacketBuffer) CloneToInbound() *PacketBuffer { // tell if a noop connection should be inserted at Input hook. Once conntrack // redefines the manipulation field as mutable, we won't need the special noop // connection. - if pk.NatDone { - newPk.NatDone = true - } + newPk.NatDone = pk.NatDone return newPk } @@ -347,7 +345,7 @@ func (pk *PacketBuffer) CloneToInbound() *PacketBuffer { // The returned packet buffer will have the network and transport headers // set if the original packet buffer did. func (pk *PacketBuffer) DeepCopyForForwarding(reservedHeaderBytes int) *PacketBuffer { - newPkt := NewPacketBuffer(PacketBufferOptions{ + newPk := NewPacketBuffer(PacketBufferOptions{ ReserveHeaderBytes: reservedHeaderBytes, Data: PayloadSince(pk.NetworkHeader()).ToVectorisedView(), IsForwardedPacket: true, @@ -355,21 +353,28 @@ func (pk *PacketBuffer) DeepCopyForForwarding(reservedHeaderBytes int) *PacketBu { consumeBytes := pk.NetworkHeader().View().Size() - if _, consumed := newPkt.NetworkHeader().Consume(consumeBytes); !consumed { + if _, consumed := newPk.NetworkHeader().Consume(consumeBytes); !consumed { panic(fmt.Sprintf("expected to consume network header %d bytes from new packet", consumeBytes)) } - newPkt.NetworkProtocolNumber = pk.NetworkProtocolNumber + newPk.NetworkProtocolNumber = pk.NetworkProtocolNumber } { consumeBytes := pk.TransportHeader().View().Size() - if _, consumed := newPkt.TransportHeader().Consume(consumeBytes); !consumed { + if _, consumed := newPk.TransportHeader().Consume(consumeBytes); !consumed { panic(fmt.Sprintf("expected to consume transport header %d bytes from new packet", consumeBytes)) } - newPkt.TransportProtocolNumber = pk.TransportProtocolNumber + newPk.TransportProtocolNumber = pk.TransportProtocolNumber } - return newPkt + // TODO(gvisor.dev/issue/5696): reimplement conntrack so that no need to + // maintain this flag in the packet. Currently conntrack needs this flag to + // tell if a noop connection should be inserted at Input hook. Once conntrack + // redefines the manipulation field as mutable, we won't need the special noop + // connection. + newPk.NatDone = pk.NatDone + + return newPk } // headerInfo stores metadata about a header in a packet. diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go index c5e896295..d45a2c05c 100644 --- a/pkg/tcpip/tcpip.go +++ b/pkg/tcpip/tcpip.go @@ -1269,6 +1269,8 @@ type TransportProtocolNumber uint32 type NetworkProtocolNumber uint32 // A StatCounter keeps track of a statistic. +// +// +stateify savable type StatCounter struct { count atomicbitops.AlignedAtomicUint64 } @@ -1995,6 +1997,8 @@ type Stats struct { } // ReceiveErrors collects packet receive errors within transport endpoint. +// +// +stateify savable type ReceiveErrors struct { // ReceiveBufferOverflow is the number of received packets dropped // due to the receive buffer being full. @@ -2012,8 +2016,10 @@ type ReceiveErrors struct { ChecksumErrors StatCounter } -// SendErrors collects packet send errors within the transport layer for -// an endpoint. +// SendErrors collects packet send errors within the transport layer for an +// endpoint. +// +// +stateify savable type SendErrors struct { // SendToNetworkFailed is the number of packets failed to be written to // the network endpoint. @@ -2024,6 +2030,8 @@ type SendErrors struct { } // ReadErrors collects segment read errors from an endpoint read call. +// +// +stateify savable type ReadErrors struct { // ReadClosed is the number of received packet drops because the endpoint // was shutdown for read. @@ -2039,6 +2047,8 @@ type ReadErrors struct { } // WriteErrors collects packet write errors from an endpoint write call. +// +// +stateify savable type WriteErrors struct { // WriteClosed is the number of packet drops because the endpoint // was shutdown for write. @@ -2054,6 +2064,8 @@ type WriteErrors struct { } // TransportEndpointStats collects statistics about the endpoint. +// +// +stateify savable type TransportEndpointStats struct { // PacketsReceived is the number of successful packet receives. PacketsReceived StatCounter diff --git a/pkg/tcpip/tests/integration/BUILD b/pkg/tcpip/tests/integration/BUILD index 181ef799e..7c998eaae 100644 --- a/pkg/tcpip/tests/integration/BUILD +++ b/pkg/tcpip/tests/integration/BUILD @@ -34,12 +34,16 @@ go_test( "//pkg/tcpip/checker", "//pkg/tcpip/header", "//pkg/tcpip/link/channel", + "//pkg/tcpip/network/arp", "//pkg/tcpip/network/ipv4", "//pkg/tcpip/network/ipv6", "//pkg/tcpip/stack", "//pkg/tcpip/tests/utils", "//pkg/tcpip/testutil", + "//pkg/tcpip/transport/tcp", "//pkg/tcpip/transport/udp", + "//pkg/waiter", + "@com_github_google_go_cmp//cmp:go_default_library", ], ) diff --git a/pkg/tcpip/tests/integration/iptables_test.go b/pkg/tcpip/tests/integration/iptables_test.go index 28b49c6be..bdf4a64b9 100644 --- a/pkg/tcpip/tests/integration/iptables_test.go +++ b/pkg/tcpip/tests/integration/iptables_test.go @@ -15,19 +15,24 @@ package iptables_test import ( + "bytes" "testing" + "github.com/google/go-cmp/cmp" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/checker" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/link/channel" + "gvisor.dev/gvisor/pkg/tcpip/network/arp" "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" "gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/tcpip/tests/utils" "gvisor.dev/gvisor/pkg/tcpip/testutil" + "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" "gvisor.dev/gvisor/pkg/tcpip/transport/udp" + "gvisor.dev/gvisor/pkg/waiter" ) type inputIfNameMatcher struct { @@ -1156,3 +1161,286 @@ func TestInputHookWithLocalForwarding(t *testing.T) { }) } } + +func TestSNAT(t *testing.T) { + const listenPort = 8080 + + type endpointAndAddresses struct { + serverEP tcpip.Endpoint + serverAddr tcpip.Address + serverReadableCH chan struct{} + + clientEP tcpip.Endpoint + clientAddr tcpip.Address + clientReadableCH chan struct{} + + nattedClientAddr tcpip.Address + } + + newEP := func(t *testing.T, s *stack.Stack, transProto tcpip.TransportProtocolNumber, netProto tcpip.NetworkProtocolNumber) (tcpip.Endpoint, chan struct{}) { + t.Helper() + var wq waiter.Queue + we, ch := waiter.NewChannelEntry(nil) + wq.EventRegister(&we, waiter.ReadableEvents) + t.Cleanup(func() { + wq.EventUnregister(&we) + }) + + ep, err := s.NewEndpoint(transProto, netProto, &wq) + if err != nil { + t.Fatalf("s.NewEndpoint(%d, %d, _): %s", transProto, netProto, err) + } + t.Cleanup(ep.Close) + + return ep, ch + } + + tests := []struct { + name string + epAndAddrs func(t *testing.T, host1Stack, routerStack, host2Stack *stack.Stack, proto tcpip.TransportProtocolNumber) endpointAndAddresses + }{ + { + name: "IPv4 host1 server with host2 client", + epAndAddrs: func(t *testing.T, host1Stack, routerStack, host2Stack *stack.Stack, proto tcpip.TransportProtocolNumber) endpointAndAddresses { + t.Helper() + + ipt := routerStack.IPTables() + filter := ipt.GetTable(stack.NATID, false /* ipv6 */) + ruleIdx := filter.BuiltinChains[stack.Postrouting] + filter.Rules[ruleIdx].Filter = stack.IPHeaderFilter{OutputInterface: utils.RouterNIC1Name} + filter.Rules[ruleIdx].Target = &stack.SNATTarget{NetworkProtocol: ipv4.ProtocolNumber, Addr: utils.RouterNIC1IPv4Addr.AddressWithPrefix.Address} + // Make sure the packet is not dropped by the next rule. + filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{} + if err := ipt.ReplaceTable(stack.NATID, filter, false /* ipv6 */); err != nil { + t.Fatalf("ipt.ReplaceTable(%d, _, %t): %s", stack.NATID, false, err) + } + + ep1, ep1WECH := newEP(t, host1Stack, proto, ipv4.ProtocolNumber) + ep2, ep2WECH := newEP(t, host2Stack, proto, ipv4.ProtocolNumber) + return endpointAndAddresses{ + serverEP: ep1, + serverAddr: utils.Host1IPv4Addr.AddressWithPrefix.Address, + serverReadableCH: ep1WECH, + + clientEP: ep2, + clientAddr: utils.Host2IPv4Addr.AddressWithPrefix.Address, + clientReadableCH: ep2WECH, + + nattedClientAddr: utils.RouterNIC1IPv4Addr.AddressWithPrefix.Address, + } + }, + }, + { + name: "IPv6 host1 server with host2 client", + epAndAddrs: func(t *testing.T, host1Stack, routerStack, host2Stack *stack.Stack, proto tcpip.TransportProtocolNumber) endpointAndAddresses { + t.Helper() + + ipt := routerStack.IPTables() + filter := ipt.GetTable(stack.NATID, true /* ipv6 */) + ruleIdx := filter.BuiltinChains[stack.Postrouting] + filter.Rules[ruleIdx].Filter = stack.IPHeaderFilter{OutputInterface: utils.RouterNIC1Name} + filter.Rules[ruleIdx].Target = &stack.SNATTarget{NetworkProtocol: ipv6.ProtocolNumber, Addr: utils.RouterNIC1IPv6Addr.AddressWithPrefix.Address} + // Make sure the packet is not dropped by the next rule. + filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{} + if err := ipt.ReplaceTable(stack.NATID, filter, true /* ipv6 */); err != nil { + t.Fatalf("ipt.ReplaceTable(%d, _, %t): %s", stack.NATID, true, err) + } + + ep1, ep1WECH := newEP(t, host1Stack, proto, ipv6.ProtocolNumber) + ep2, ep2WECH := newEP(t, host2Stack, proto, ipv6.ProtocolNumber) + return endpointAndAddresses{ + serverEP: ep1, + serverAddr: utils.Host1IPv6Addr.AddressWithPrefix.Address, + serverReadableCH: ep1WECH, + + clientEP: ep2, + clientAddr: utils.Host2IPv6Addr.AddressWithPrefix.Address, + clientReadableCH: ep2WECH, + + nattedClientAddr: utils.RouterNIC1IPv6Addr.AddressWithPrefix.Address, + } + }, + }, + } + + subTests := []struct { + name string + proto tcpip.TransportProtocolNumber + expectedConnectErr tcpip.Error + setupServer func(t *testing.T, ep tcpip.Endpoint) + setupServerConn func(t *testing.T, ep tcpip.Endpoint, ch <-chan struct{}, clientAddr tcpip.FullAddress) (tcpip.Endpoint, chan struct{}) + needRemoteAddr bool + }{ + { + name: "UDP", + proto: udp.ProtocolNumber, + expectedConnectErr: nil, + setupServerConn: func(t *testing.T, ep tcpip.Endpoint, _ <-chan struct{}, clientAddr tcpip.FullAddress) (tcpip.Endpoint, chan struct{}) { + t.Helper() + + if err := ep.Connect(clientAddr); err != nil { + t.Fatalf("ep.Connect(%#v): %s", clientAddr, err) + } + return nil, nil + }, + needRemoteAddr: true, + }, + { + name: "TCP", + proto: tcp.ProtocolNumber, + expectedConnectErr: &tcpip.ErrConnectStarted{}, + setupServer: func(t *testing.T, ep tcpip.Endpoint) { + t.Helper() + + if err := ep.Listen(1); err != nil { + t.Fatalf("ep.Listen(1): %s", err) + } + }, + setupServerConn: func(t *testing.T, ep tcpip.Endpoint, ch <-chan struct{}, clientAddr tcpip.FullAddress) (tcpip.Endpoint, chan struct{}) { + t.Helper() + + var addr tcpip.FullAddress + for { + newEP, wq, err := ep.Accept(&addr) + if _, ok := err.(*tcpip.ErrWouldBlock); ok { + <-ch + continue + } + if err != nil { + t.Fatalf("ep.Accept(_): %s", err) + } + if diff := cmp.Diff(clientAddr, addr, checker.IgnoreCmpPath( + "NIC", + )); diff != "" { + t.Errorf("accepted address mismatch (-want +got):\n%s", diff) + } + + we, newCH := waiter.NewChannelEntry(nil) + wq.EventRegister(&we, waiter.ReadableEvents) + return newEP, newCH + } + }, + needRemoteAddr: false, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + for _, subTest := range subTests { + t.Run(subTest.name, func(t *testing.T) { + stackOpts := stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol, ipv6.NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol, tcp.NewProtocol}, + } + + host1Stack := stack.New(stackOpts) + routerStack := stack.New(stackOpts) + host2Stack := stack.New(stackOpts) + utils.SetupRoutedStacks(t, host1Stack, routerStack, host2Stack) + + epsAndAddrs := test.epAndAddrs(t, host1Stack, routerStack, host2Stack, subTest.proto) + serverAddr := tcpip.FullAddress{Addr: epsAndAddrs.serverAddr, Port: listenPort} + if err := epsAndAddrs.serverEP.Bind(serverAddr); err != nil { + t.Fatalf("epsAndAddrs.serverEP.Bind(%#v): %s", serverAddr, err) + } + clientAddr := tcpip.FullAddress{Addr: epsAndAddrs.clientAddr} + if err := epsAndAddrs.clientEP.Bind(clientAddr); err != nil { + t.Fatalf("epsAndAddrs.clientEP.Bind(%#v): %s", clientAddr, err) + } + + if subTest.setupServer != nil { + subTest.setupServer(t, epsAndAddrs.serverEP) + } + { + err := epsAndAddrs.clientEP.Connect(serverAddr) + if diff := cmp.Diff(subTest.expectedConnectErr, err); diff != "" { + t.Fatalf("unexpected error from epsAndAddrs.clientEP.Connect(%#v), (-want, +got):\n%s", serverAddr, diff) + } + } + nattedClientAddr := tcpip.FullAddress{Addr: epsAndAddrs.nattedClientAddr} + if addr, err := epsAndAddrs.clientEP.GetLocalAddress(); err != nil { + t.Fatalf("epsAndAddrs.clientEP.GetLocalAddress(): %s", err) + } else { + nattedClientAddr.Port = addr.Port + } + + serverEP := epsAndAddrs.serverEP + serverCH := epsAndAddrs.serverReadableCH + if ep, ch := subTest.setupServerConn(t, serverEP, serverCH, nattedClientAddr); ep != nil { + defer ep.Close() + serverEP = ep + serverCH = ch + } + + write := func(ep tcpip.Endpoint, data []byte) { + t.Helper() + + var r bytes.Reader + r.Reset(data) + var wOpts tcpip.WriteOptions + n, err := ep.Write(&r, wOpts) + if err != nil { + t.Fatalf("ep.Write(_, %#v): %s", wOpts, err) + } + if want := int64(len(data)); n != want { + t.Fatalf("got ep.Write(_, %#v) = (%d, _), want = (%d, _)", wOpts, n, want) + } + } + + read := func(ch chan struct{}, ep tcpip.Endpoint, data []byte, expectedFrom tcpip.FullAddress) { + t.Helper() + + var buf bytes.Buffer + var res tcpip.ReadResult + for { + var err tcpip.Error + opts := tcpip.ReadOptions{NeedRemoteAddr: subTest.needRemoteAddr} + res, err = ep.Read(&buf, opts) + if _, ok := err.(*tcpip.ErrWouldBlock); ok { + <-ch + continue + } + if err != nil { + t.Fatalf("ep.Read(_, %d, %#v): %s", len(data), opts, err) + } + break + } + + readResult := tcpip.ReadResult{ + Count: len(data), + Total: len(data), + } + if subTest.needRemoteAddr { + readResult.RemoteAddr = expectedFrom + } + if diff := cmp.Diff(readResult, res, checker.IgnoreCmpPath( + "ControlMessages", + "RemoteAddr.NIC", + )); diff != "" { + t.Errorf("ep.Read: unexpected result (-want +got):\n%s", diff) + } + if diff := cmp.Diff(buf.Bytes(), data); diff != "" { + t.Errorf("received data mismatch (-want +got):\n%s", diff) + } + + if t.Failed() { + t.FailNow() + } + } + + { + data := []byte{1, 2, 3, 4} + write(epsAndAddrs.clientEP, data) + read(serverCH, serverEP, data, nattedClientAddr) + } + + { + data := []byte{5, 6, 7, 8, 9, 10, 11, 12} + write(serverEP, data) + read(epsAndAddrs.clientReadableCH, epsAndAddrs.clientEP, data, serverAddr) + } + }) + } + }) + } +} diff --git a/pkg/tcpip/tests/utils/utils.go b/pkg/tcpip/tests/utils/utils.go index 947bcc7b1..c69410859 100644 --- a/pkg/tcpip/tests/utils/utils.go +++ b/pkg/tcpip/tests/utils/utils.go @@ -40,6 +40,14 @@ const ( Host2NICID = 4 ) +// Common NIC names used by tests. +const ( + Host1NICName = "host1NIC" + RouterNIC1Name = "routerNIC1" + RouterNIC2Name = "routerNIC2" + Host2NICName = "host2NIC" +) + // Common link addresses used by tests. const ( LinkAddr1 = tcpip.LinkAddress("\x02\x03\x03\x04\x05\x06") @@ -211,17 +219,29 @@ func SetupRoutedStacks(t *testing.T, host1Stack, routerStack, host2Stack *stack. host1NIC, routerNIC1 := pipe.New(LinkAddr1, LinkAddr2) routerNIC2, host2NIC := pipe.New(LinkAddr3, LinkAddr4) - if err := host1Stack.CreateNIC(Host1NICID, NewEthernetEndpoint(host1NIC)); err != nil { - t.Fatalf("host1Stack.CreateNIC(%d, _): %s", Host1NICID, err) + { + opts := stack.NICOptions{Name: Host1NICName} + if err := host1Stack.CreateNICWithOptions(Host1NICID, NewEthernetEndpoint(host1NIC), opts); err != nil { + t.Fatalf("host1Stack.CreateNICWithOptions(%d, _, %#v): %s", Host1NICID, opts, err) + } } - if err := routerStack.CreateNIC(RouterNICID1, NewEthernetEndpoint(routerNIC1)); err != nil { - t.Fatalf("routerStack.CreateNIC(%d, _): %s", RouterNICID1, err) + { + opts := stack.NICOptions{Name: RouterNIC1Name} + if err := routerStack.CreateNICWithOptions(RouterNICID1, NewEthernetEndpoint(routerNIC1), opts); err != nil { + t.Fatalf("routerStack.CreateNICWithOptions(%d, _, %#v): %s", RouterNICID1, opts, err) + } } - if err := routerStack.CreateNIC(RouterNICID2, NewEthernetEndpoint(routerNIC2)); err != nil { - t.Fatalf("routerStack.CreateNIC(%d, _): %s", RouterNICID2, err) + { + opts := stack.NICOptions{Name: RouterNIC2Name} + if err := routerStack.CreateNICWithOptions(RouterNICID2, NewEthernetEndpoint(routerNIC2), opts); err != nil { + t.Fatalf("routerStack.CreateNICWithOptions(%d, _, %#v): %s", RouterNICID2, opts, err) + } } - if err := host2Stack.CreateNIC(Host2NICID, NewEthernetEndpoint(host2NIC)); err != nil { - t.Fatalf("host2Stack.CreateNIC(%d, _): %s", Host2NICID, err) + { + opts := stack.NICOptions{Name: Host2NICName} + if err := host2Stack.CreateNICWithOptions(Host2NICID, NewEthernetEndpoint(host2NIC), opts); err != nil { + t.Fatalf("host2Stack.CreateNICWithOptions(%d, _, %#v): %s", Host2NICID, opts, err) + } } if err := routerStack.SetForwardingDefaultAndAllNICs(ipv4.ProtocolNumber, true); err != nil { diff --git a/pkg/tcpip/transport/icmp/BUILD b/pkg/tcpip/transport/icmp/BUILD index bbc0e3ecc..4718ec4ec 100644 --- a/pkg/tcpip/transport/icmp/BUILD +++ b/pkg/tcpip/transport/icmp/BUILD @@ -33,6 +33,8 @@ go_library( "//pkg/tcpip/header", "//pkg/tcpip/ports", "//pkg/tcpip/stack", + "//pkg/tcpip/transport", + "//pkg/tcpip/transport/internal/network", "//pkg/tcpip/transport/raw", "//pkg/tcpip/transport/tcp", "//pkg/waiter", diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go index 1e519085d..bb0db9f70 100644 --- a/pkg/tcpip/transport/icmp/endpoint.go +++ b/pkg/tcpip/transport/icmp/endpoint.go @@ -15,6 +15,7 @@ package icmp import ( + "fmt" "io" "time" @@ -24,6 +25,8 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/ports" "gvisor.dev/gvisor/pkg/tcpip/stack" + "gvisor.dev/gvisor/pkg/tcpip/transport" + "gvisor.dev/gvisor/pkg/tcpip/transport/internal/network" "gvisor.dev/gvisor/pkg/waiter" ) @@ -35,15 +38,6 @@ type icmpPacket struct { receivedAt time.Time `state:".(int64)"` } -type endpointState int - -const ( - stateInitial endpointState = iota - stateBound - stateConnected - stateClosed -) - // endpoint represents an ICMP endpoint. This struct serves as the interface // between users of the endpoint and the protocol implementation; it is legal to // have concurrent goroutines make calls into the endpoint, they are properly @@ -51,14 +45,17 @@ const ( // // +stateify savable type endpoint struct { - stack.TransportEndpointInfo tcpip.DefaultSocketOptionsHandler // The following fields are initialized at creation time and are // immutable. stack *stack.Stack `state:"manual"` + transProto tcpip.TransportProtocolNumber waiterQueue *waiter.Queue uniqueID uint64 + net network.Endpoint + stats tcpip.TransportEndpointStats + ops tcpip.SocketOptions // The following fields are used to manage the receive queue, and are // protected by rcvMu. @@ -70,38 +67,23 @@ type endpoint struct { // The following fields are protected by the mu mutex. mu sync.RWMutex `state:"nosave"` - // shutdownFlags represent the current shutdown state of the endpoint. - shutdownFlags tcpip.ShutdownFlags - state endpointState - route *stack.Route `state:"manual"` - ttl uint8 - stats tcpip.TransportEndpointStats `state:"nosave"` - - // owner is used to get uid and gid of the packet. - owner tcpip.PacketOwner - - // ops is used to get socket level options. - ops tcpip.SocketOptions - // frozen indicates if the packets should be delivered to the endpoint // during restore. frozen bool + ident uint16 } func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, tcpip.Error) { ep := &endpoint{ - stack: s, - TransportEndpointInfo: stack.TransportEndpointInfo{ - NetProto: netProto, - TransProto: transProto, - }, + stack: s, + transProto: transProto, waiterQueue: waiterQueue, - state: stateInitial, uniqueID: s.UniqueID(), } ep.ops.InitHandler(ep, ep.stack, tcpip.GetStackSendBufferLimits, tcpip.GetStackReceiveBufferLimits) ep.ops.SetSendBufferSize(32*1024, false /* notify */) ep.ops.SetReceiveBufferSize(32*1024, false /* notify */) + ep.net.Init(s, netProto, transProto, &ep.ops) // Override with stack defaults. var ss tcpip.SendBufferSizeOption @@ -128,35 +110,40 @@ func (e *endpoint) Abort() { // Close puts the endpoint in a closed state and frees all resources // associated with it. func (e *endpoint) Close() { - e.mu.Lock() - e.shutdownFlags = tcpip.ShutdownRead | tcpip.ShutdownWrite - switch e.state { - case stateBound, stateConnected: - bindToDevice := tcpip.NICID(e.ops.GetBindToDevice()) - e.stack.UnregisterTransportEndpoint([]tcpip.NetworkProtocolNumber{e.NetProto}, e.TransProto, e.ID, e, ports.Flags{}, bindToDevice) - } + notify := func() bool { + e.mu.Lock() + defer e.mu.Unlock() + + switch state := e.net.State(); state { + case transport.DatagramEndpointStateInitial: + case transport.DatagramEndpointStateClosed: + return false + case transport.DatagramEndpointStateBound, transport.DatagramEndpointStateConnected: + info := e.net.Info() + info.ID.LocalPort = e.ident + e.stack.UnregisterTransportEndpoint([]tcpip.NetworkProtocolNumber{info.NetProto}, e.transProto, info.ID, e, ports.Flags{}, tcpip.NICID(e.ops.GetBindToDevice())) + default: + panic(fmt.Sprintf("unhandled state = %s", state)) + } - // Close the receive list and drain it. - e.rcvMu.Lock() - e.rcvClosed = true - e.rcvBufSize = 0 - for !e.rcvList.Empty() { - p := e.rcvList.Front() - e.rcvList.Remove(p) - } - e.rcvMu.Unlock() + e.net.Shutdown() + e.net.Close() - if e.route != nil { - e.route.Release() - e.route = nil - } - - // Update the state. - e.state = stateClosed + e.rcvMu.Lock() + defer e.rcvMu.Unlock() + e.rcvClosed = true + e.rcvBufSize = 0 + for !e.rcvList.Empty() { + p := e.rcvList.Front() + e.rcvList.Remove(p) + } - e.mu.Unlock() + return true + }() - e.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.ReadableEvents | waiter.WritableEvents) + if notify { + e.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.ReadableEvents | waiter.WritableEvents) + } } // ModerateRecvBuf implements tcpip.Endpoint.ModerateRecvBuf. @@ -164,7 +151,7 @@ func (*endpoint) ModerateRecvBuf(int) {} // SetOwner implements tcpip.Endpoint.SetOwner. func (e *endpoint) SetOwner(owner tcpip.PacketOwner) { - e.owner = owner + e.net.SetOwner(owner) } // Read implements tcpip.Endpoint.Read. @@ -214,13 +201,12 @@ func (e *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResult // // Returns true for retry if preparation should be retried. // +checklocks:e.mu -func (e *endpoint) prepareForWrite(to *tcpip.FullAddress) (retry bool, err tcpip.Error) { - switch e.state { - case stateInitial: - case stateConnected: +func (e *endpoint) prepareForWriteInner(to *tcpip.FullAddress) (retry bool, err tcpip.Error) { + switch e.net.State() { + case transport.DatagramEndpointStateInitial: + case transport.DatagramEndpointStateConnected: return false, nil - - case stateBound: + case transport.DatagramEndpointStateBound: if to == nil { return false, &tcpip.ErrDestinationRequired{} } @@ -235,7 +221,7 @@ func (e *endpoint) prepareForWrite(to *tcpip.FullAddress) (retry bool, err tcpip // The state changed when we released the shared locked and re-acquired // it in exclusive mode. Try again. - if e.state != stateInitial { + if e.net.State() != transport.DatagramEndpointStateInitial { return true, nil } @@ -270,27 +256,15 @@ func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcp return n, err } -func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcpip.Error) { - // MSG_MORE is unimplemented. (This also means that MSG_EOR is a no-op.) - if opts.More { - return 0, &tcpip.ErrInvalidOptionValue{} - } - - to := opts.To - +func (e *endpoint) prepareForWrite(opts tcpip.WriteOptions) (network.WriteContext, uint16, tcpip.Error) { e.mu.RLock() defer e.mu.RUnlock() - // If we've shutdown with SHUT_WR we are in an invalid state for sending. - if e.shutdownFlags&tcpip.ShutdownWrite != 0 { - return 0, &tcpip.ErrClosedForSend{} - } - // Prepare for write. for { - retry, err := e.prepareForWrite(to) + retry, err := e.prepareForWriteInner(opts.To) if err != nil { - return 0, err + return network.WriteContext{}, 0, err } if !retry { @@ -298,36 +272,16 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcp } } - route := e.route - if to != nil { - // Reject destination address if it goes through a different - // NIC than the endpoint was bound to. - nicID := to.NIC - if nicID == 0 { - nicID = tcpip.NICID(e.ops.GetBindToDevice()) - } - if e.BindNICID != 0 { - if nicID != 0 && nicID != e.BindNICID { - return 0, &tcpip.ErrNoRoute{} - } - - nicID = e.BindNICID - } - - dst, netProto, err := e.checkV4MappedLocked(*to) - if err != nil { - return 0, err - } - - // Find the endpoint. - r, err := e.stack.FindRoute(nicID, e.BindAddr, dst.Addr, netProto, false /* multicastLoop */) - if err != nil { - return 0, err - } - defer r.Release() + ctx, err := e.net.AcquireContextForWrite(opts) + return ctx, e.ident, err +} - route = r +func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcpip.Error) { + ctx, ident, err := e.prepareForWrite(opts) + if err != nil { + return 0, err } + defer ctx.Release() // TODO(https://gvisor.dev/issue/6538): Avoid this allocation. v := make([]byte, p.Len()) @@ -335,17 +289,18 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcp return 0, &tcpip.ErrBadBuffer{} } - var err tcpip.Error - switch e.NetProto { + switch netProto, pktInfo := e.net.NetProto(), ctx.PacketInfo(); netProto { case header.IPv4ProtocolNumber: - err = send4(route, e.ID.LocalPort, v, e.ttl, e.owner) + if err := send4(e.stack, &ctx, ident, v, pktInfo.MaxHeaderLength); err != nil { + return 0, err + } case header.IPv6ProtocolNumber: - err = send6(route, e.ID.LocalPort, v, e.ttl) - } - - if err != nil { - return 0, err + if err := send6(e.stack, &ctx, ident, v, pktInfo.LocalAddress, pktInfo.RemoteAddress, pktInfo.MaxHeaderLength); err != nil { + return 0, err + } + default: + panic(fmt.Sprintf("unhandled network protocol = %d", netProto)) } return int64(len(v)), nil @@ -358,24 +313,17 @@ func (e *endpoint) HasNIC(id int32) bool { return e.stack.HasNIC(tcpip.NICID(id)) } -// SetSockOpt sets a socket option. -func (*endpoint) SetSockOpt(tcpip.SettableSocketOption) tcpip.Error { - return nil +// SetSockOpt implements tcpip.Endpoint. +func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) tcpip.Error { + return e.net.SetSockOpt(opt) } -// SetSockOptInt sets a socket option. Currently not supported. +// SetSockOptInt implements tcpip.Endpoint. func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) tcpip.Error { - switch opt { - case tcpip.TTLOption: - e.mu.Lock() - e.ttl = uint8(v) - e.mu.Unlock() - - } - return nil + return e.net.SetSockOptInt(opt, v) } -// GetSockOptInt implements tcpip.Endpoint.GetSockOptInt. +// GetSockOptInt implements tcpip.Endpoint. func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, tcpip.Error) { switch opt { case tcpip.ReceiveQueueSizeOption: @@ -388,31 +336,24 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, tcpip.Error) { e.rcvMu.Unlock() return v, nil - case tcpip.TTLOption: - e.rcvMu.Lock() - v := int(e.ttl) - e.rcvMu.Unlock() - return v, nil - default: - return -1, &tcpip.ErrUnknownProtocolOption{} + return e.net.GetSockOptInt(opt) } } -// GetSockOpt implements tcpip.Endpoint.GetSockOpt. -func (*endpoint) GetSockOpt(tcpip.GettableSocketOption) tcpip.Error { - return &tcpip.ErrUnknownProtocolOption{} +// GetSockOpt implements tcpip.Endpoint. +func (e *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) tcpip.Error { + return e.net.GetSockOpt(opt) } -func send4(r *stack.Route, ident uint16, data buffer.View, ttl uint8, owner tcpip.PacketOwner) tcpip.Error { +func send4(s *stack.Stack, ctx *network.WriteContext, ident uint16, data buffer.View, maxHeaderLength uint16) tcpip.Error { if len(data) < header.ICMPv4MinimumSize { return &tcpip.ErrInvalidEndpointState{} } pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - ReserveHeaderBytes: header.ICMPv4MinimumSize + int(r.MaxHeaderLength()), + ReserveHeaderBytes: header.ICMPv4MinimumSize + int(maxHeaderLength), }) - pkt.Owner = owner icmpv4 := header.ICMPv4(pkt.TransportHeader().Push(header.ICMPv4MinimumSize)) pkt.TransportProtocolNumber = header.ICMPv4ProtocolNumber @@ -427,36 +368,31 @@ func send4(r *stack.Route, ident uint16, data buffer.View, ttl uint8, owner tcpi return &tcpip.ErrInvalidEndpointState{} } - // Because this icmp endpoint is implemented in the transport layer, we can - // only increment the 'stack-wide' stats but we can't increment the - // 'per-NetworkEndpoint' stats. - sentStat := r.Stats().ICMP.V4.PacketsSent.EchoRequest - icmpv4.SetChecksum(0) icmpv4.SetChecksum(^header.Checksum(icmpv4, header.Checksum(data, 0))) - pkt.Data().AppendView(data) - if ttl == 0 { - ttl = r.DefaultTTL() - } + // Because this icmp endpoint is implemented in the transport layer, we can + // only increment the 'stack-wide' stats but we can't increment the + // 'per-NetworkEndpoint' stats. + stats := s.Stats().ICMP.V4.PacketsSent - if err := r.WritePacket(stack.NetworkHeaderParams{Protocol: header.ICMPv4ProtocolNumber, TTL: ttl, TOS: stack.DefaultTOS}, pkt); err != nil { - r.Stats().ICMP.V4.PacketsSent.Dropped.Increment() + if err := ctx.WritePacket(pkt, false /* headerIncluded */); err != nil { + stats.Dropped.Increment() return err } - sentStat.Increment() + stats.EchoRequest.Increment() return nil } -func send6(r *stack.Route, ident uint16, data buffer.View, ttl uint8) tcpip.Error { +func send6(s *stack.Stack, ctx *network.WriteContext, ident uint16, data buffer.View, src, dst tcpip.Address, maxHeaderLength uint16) tcpip.Error { if len(data) < header.ICMPv6EchoMinimumSize { return &tcpip.ErrInvalidEndpointState{} } pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - ReserveHeaderBytes: header.ICMPv6MinimumSize + int(r.MaxHeaderLength()), + ReserveHeaderBytes: header.ICMPv6MinimumSize + int(maxHeaderLength), }) icmpv6 := header.ICMPv6(pkt.TransportHeader().Push(header.ICMPv6MinimumSize)) @@ -469,43 +405,31 @@ func send6(r *stack.Route, ident uint16, data buffer.View, ttl uint8) tcpip.Erro if icmpv6.Type() != header.ICMPv6EchoRequest || icmpv6.Code() != 0 { return &tcpip.ErrInvalidEndpointState{} } - // Because this icmp endpoint is implemented in the transport layer, we can - // only increment the 'stack-wide' stats but we can't increment the - // 'per-NetworkEndpoint' stats. - sentStat := r.Stats().ICMP.V6.PacketsSent.EchoRequest pkt.Data().AppendView(data) dataRange := pkt.Data().AsRange() icmpv6.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{ Header: icmpv6, - Src: r.LocalAddress(), - Dst: r.RemoteAddress(), + Src: src, + Dst: dst, PayloadCsum: dataRange.Checksum(), PayloadLen: dataRange.Size(), })) - if ttl == 0 { - ttl = r.DefaultTTL() - } + // Because this icmp endpoint is implemented in the transport layer, we can + // only increment the 'stack-wide' stats but we can't increment the + // 'per-NetworkEndpoint' stats. + stats := s.Stats().ICMP.V6.PacketsSent - if err := r.WritePacket(stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: ttl, TOS: stack.DefaultTOS}, pkt); err != nil { - r.Stats().ICMP.V6.PacketsSent.Dropped.Increment() + if err := ctx.WritePacket(pkt, false /* headerIncluded */); err != nil { + stats.Dropped.Increment() + return err } - sentStat.Increment() + stats.EchoRequest.Increment() return nil } -// checkV4MappedLocked determines the effective network protocol and converts -// addr to its canonical form. -func (e *endpoint) checkV4MappedLocked(addr tcpip.FullAddress) (tcpip.FullAddress, tcpip.NetworkProtocolNumber, tcpip.Error) { - unwrapped, netProto, err := e.TransportEndpointInfo.AddrNetProtoLocked(addr, false /* v6only */) - if err != nil { - return tcpip.FullAddress{}, 0, err - } - return unwrapped, netProto, nil -} - // Disconnect implements tcpip.Endpoint.Disconnect. func (*endpoint) Disconnect() tcpip.Error { return &tcpip.ErrNotSupported{} @@ -516,59 +440,21 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) tcpip.Error { e.mu.Lock() defer e.mu.Unlock() - nicID := addr.NIC - localPort := uint16(0) - switch e.state { - case stateInitial: - case stateBound, stateConnected: - localPort = e.ID.LocalPort - if e.BindNICID == 0 { - break - } + err := e.net.ConnectAndThen(addr, func(netProto tcpip.NetworkProtocolNumber, previousID, nextID stack.TransportEndpointID) tcpip.Error { + nextID.LocalPort = e.ident - if nicID != 0 && nicID != e.BindNICID { - return &tcpip.ErrInvalidEndpointState{} + nextID, err := e.registerWithStack(netProto, nextID) + if err != nil { + return err } - nicID = e.BindNICID - default: - return &tcpip.ErrInvalidEndpointState{} - } - - addr, netProto, err := e.checkV4MappedLocked(addr) - if err != nil { - return err - } - - // Find a route to the desired destination. - r, err := e.stack.FindRoute(nicID, e.BindAddr, addr.Addr, netProto, false /* multicastLoop */) - if err != nil { - return err - } - - id := stack.TransportEndpointID{ - LocalAddress: r.LocalAddress(), - LocalPort: localPort, - RemoteAddress: r.RemoteAddress(), - } - - // Even if we're connected, this endpoint can still be used to send - // packets on a different network protocol, so we register both even if - // v6only is set to false and this is an ipv6 endpoint. - netProtos := []tcpip.NetworkProtocolNumber{netProto} - - id, err = e.registerWithStack(nicID, netProtos, id) + e.ident = nextID.LocalPort + return nil + }) if err != nil { - r.Release() return err } - e.ID = id - e.route = r - e.RegisterNICID = nicID - - e.state = stateConnected - e.rcvMu.Lock() e.rcvReady = true e.rcvMu.Unlock() @@ -586,10 +472,19 @@ func (*endpoint) ConnectEndpoint(tcpip.Endpoint) tcpip.Error { func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) tcpip.Error { e.mu.Lock() defer e.mu.Unlock() - e.shutdownFlags |= flags - if e.state != stateConnected { + switch state := e.net.State(); state { + case transport.DatagramEndpointStateInitial, transport.DatagramEndpointStateClosed: return &tcpip.ErrNotConnected{} + case transport.DatagramEndpointStateBound, transport.DatagramEndpointStateConnected: + default: + panic(fmt.Sprintf("unhandled state = %s", state)) + } + + if flags&tcpip.ShutdownWrite != 0 { + if err := e.net.Shutdown(); err != nil { + return err + } } if flags&tcpip.ShutdownRead != 0 { @@ -616,19 +511,18 @@ func (*endpoint) Accept(*tcpip.FullAddress) (tcpip.Endpoint, *waiter.Queue, tcpi return nil, nil, &tcpip.ErrNotSupported{} } -func (e *endpoint) registerWithStack(_ tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, id stack.TransportEndpointID) (stack.TransportEndpointID, tcpip.Error) { +func (e *endpoint) registerWithStack(netProto tcpip.NetworkProtocolNumber, id stack.TransportEndpointID) (stack.TransportEndpointID, tcpip.Error) { bindToDevice := tcpip.NICID(e.ops.GetBindToDevice()) if id.LocalPort != 0 { // The endpoint already has a local port, just attempt to // register it. - err := e.stack.RegisterTransportEndpoint(netProtos, e.TransProto, id, e, ports.Flags{}, bindToDevice) - return id, err + return id, e.stack.RegisterTransportEndpoint([]tcpip.NetworkProtocolNumber{netProto}, e.transProto, id, e, ports.Flags{}, bindToDevice) } // We need to find a port for the endpoint. _, err := e.stack.PickEphemeralPort(e.stack.Rand(), func(p uint16) (bool, tcpip.Error) { id.LocalPort = p - err := e.stack.RegisterTransportEndpoint(netProtos, e.TransProto, id, e, ports.Flags{}, bindToDevice) + err := e.stack.RegisterTransportEndpoint([]tcpip.NetworkProtocolNumber{netProto}, e.transProto, id, e, ports.Flags{}, bindToDevice) switch err.(type) { case nil: return true, nil @@ -645,42 +539,27 @@ func (e *endpoint) registerWithStack(_ tcpip.NICID, netProtos []tcpip.NetworkPro func (e *endpoint) bindLocked(addr tcpip.FullAddress) tcpip.Error { // Don't allow binding once endpoint is not in the initial state // anymore. - if e.state != stateInitial { + if e.net.State() != transport.DatagramEndpointStateInitial { return &tcpip.ErrInvalidEndpointState{} } - addr, netProto, err := e.checkV4MappedLocked(addr) - if err != nil { - return err - } - - // Expand netProtos to include v4 and v6 if the caller is binding to a - // wildcard (empty) address, and this is an IPv6 endpoint with v6only - // set to false. - netProtos := []tcpip.NetworkProtocolNumber{netProto} - - if len(addr.Addr) != 0 { - // A local address was specified, verify that it's valid. - if e.stack.CheckLocalAddress(addr.NIC, netProto, addr.Addr) == 0 { - return &tcpip.ErrBadLocalAddress{} + err := e.net.BindAndThen(addr, func(boundNetProto tcpip.NetworkProtocolNumber, boundAddr tcpip.Address) tcpip.Error { + id := stack.TransportEndpointID{ + LocalPort: addr.Port, + LocalAddress: addr.Addr, + } + id, err := e.registerWithStack(boundNetProto, id) + if err != nil { + return err } - } - id := stack.TransportEndpointID{ - LocalPort: addr.Port, - LocalAddress: addr.Addr, - } - id, err = e.registerWithStack(addr.NIC, netProtos, id) + e.ident = id.LocalPort + return nil + }) if err != nil { return err } - e.ID = id - e.RegisterNICID = addr.NIC - - // Mark endpoint as bound. - e.state = stateBound - e.rcvMu.Lock() e.rcvReady = true e.rcvMu.Unlock() @@ -692,7 +571,7 @@ func (e *endpoint) isBroadcastOrMulticast(nicID tcpip.NICID, addr tcpip.Address) return addr == header.IPv4Broadcast || header.IsV4MulticastAddress(addr) || header.IsV6MulticastAddress(addr) || - e.stack.IsSubnetBroadcast(nicID, e.NetProto, addr) + e.stack.IsSubnetBroadcast(nicID, e.net.NetProto(), addr) } // Bind binds the endpoint to a specific local address and port. @@ -705,15 +584,7 @@ func (e *endpoint) Bind(addr tcpip.FullAddress) tcpip.Error { e.mu.Lock() defer e.mu.Unlock() - err := e.bindLocked(addr) - if err != nil { - return err - } - - e.BindNICID = addr.NIC - e.BindAddr = addr.Addr - - return nil + return e.bindLocked(addr) } // GetLocalAddress returns the address to which the endpoint is bound. @@ -721,11 +592,9 @@ func (e *endpoint) GetLocalAddress() (tcpip.FullAddress, tcpip.Error) { e.mu.RLock() defer e.mu.RUnlock() - return tcpip.FullAddress{ - NIC: e.RegisterNICID, - Addr: e.ID.LocalAddress, - Port: e.ID.LocalPort, - }, nil + addr := e.net.GetLocalAddress() + addr.Port = e.ident + return addr, nil } // GetRemoteAddress returns the address to which the endpoint is connected. @@ -733,15 +602,11 @@ func (e *endpoint) GetRemoteAddress() (tcpip.FullAddress, tcpip.Error) { e.mu.RLock() defer e.mu.RUnlock() - if e.state != stateConnected { - return tcpip.FullAddress{}, &tcpip.ErrNotConnected{} + if addr, connected := e.net.GetRemoteAddress(); connected { + return addr, nil } - return tcpip.FullAddress{ - NIC: e.RegisterNICID, - Addr: e.ID.RemoteAddress, - Port: e.ID.RemotePort, - }, nil + return tcpip.FullAddress{}, &tcpip.ErrNotConnected{} } // Readiness returns the current readiness of the endpoint. For example, if @@ -766,7 +631,7 @@ func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask { // endpoint. func (e *endpoint) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketBuffer) { // Only accept echo replies. - switch e.NetProto { + switch e.net.NetProto() { case header.IPv4ProtocolNumber: h := header.ICMPv4(pkt.TransportHeader().View()) if len(h) < header.ICMPv4MinimumSize || h.Type() != header.ICMPv4EchoReply { @@ -840,9 +705,9 @@ func (e *endpoint) State() uint32 { // Info returns a copy of the endpoint info. func (e *endpoint) Info() tcpip.EndpointInfo { e.mu.RLock() - // Make a copy of the endpoint info. - ret := e.TransportEndpointInfo - e.mu.RUnlock() + defer e.mu.RUnlock() + ret := e.net.Info() + ret.ID.LocalPort = e.ident return &ret } diff --git a/pkg/tcpip/transport/icmp/endpoint_state.go b/pkg/tcpip/transport/icmp/endpoint_state.go index b8b839e4a..dfe453ff9 100644 --- a/pkg/tcpip/transport/icmp/endpoint_state.go +++ b/pkg/tcpip/transport/icmp/endpoint_state.go @@ -15,11 +15,13 @@ package icmp import ( + "fmt" "time" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/stack" + "gvisor.dev/gvisor/pkg/tcpip/transport" ) // saveReceivedAt is invoked by stateify. @@ -61,29 +63,24 @@ func (e *endpoint) beforeSave() { // Resume implements tcpip.ResumableEndpoint.Resume. func (e *endpoint) Resume(s *stack.Stack) { e.thaw() + + e.net.Resume(s) + e.stack = s e.ops.InitHandler(e, e.stack, tcpip.GetStackSendBufferLimits, tcpip.GetStackReceiveBufferLimits) - if e.state != stateBound && e.state != stateConnected { - return - } - - var err tcpip.Error - if e.state == stateConnected { - e.route, err = e.stack.FindRoute(e.RegisterNICID, e.BindAddr, e.ID.RemoteAddress, e.NetProto, false /* multicastLoop */) + switch state := e.net.State(); state { + case transport.DatagramEndpointStateInitial, transport.DatagramEndpointStateClosed: + case transport.DatagramEndpointStateBound, transport.DatagramEndpointStateConnected: + var err tcpip.Error + info := e.net.Info() + info.ID.LocalPort = e.ident + info.ID, err = e.registerWithStack(info.NetProto, info.ID) if err != nil { - panic(err) + panic(fmt.Sprintf("e.registerWithStack(%d, %#v): %s", info.NetProto, info.ID, err)) } - - e.ID.LocalAddress = e.route.LocalAddress() - } else if len(e.ID.LocalAddress) != 0 { // stateBound - if e.stack.CheckLocalAddress(e.RegisterNICID, e.NetProto, e.ID.LocalAddress) == 0 { - panic(&tcpip.ErrBadLocalAddress{}) - } - } - - e.ID, err = e.registerWithStack(e.RegisterNICID, []tcpip.NetworkProtocolNumber{e.NetProto}, e.ID) - if err != nil { - panic(err) + e.ident = info.ID.LocalPort + default: + panic(fmt.Sprintf("unhandled state = %s", state)) } } diff --git a/pkg/tcpip/transport/internal/network/BUILD b/pkg/tcpip/transport/internal/network/BUILD index b1edce39b..3818cb04e 100644 --- a/pkg/tcpip/transport/internal/network/BUILD +++ b/pkg/tcpip/transport/internal/network/BUILD @@ -9,6 +9,7 @@ go_library( "endpoint_state.go", ], visibility = [ + "//pkg/tcpip/transport/icmp:__pkg__", "//pkg/tcpip/transport/raw:__pkg__", "//pkg/tcpip/transport/udp:__pkg__", ], diff --git a/pkg/tcpip/transport/packet/endpoint.go b/pkg/tcpip/transport/packet/endpoint.go index e4a64e191..689427d53 100644 --- a/pkg/tcpip/transport/packet/endpoint.go +++ b/pkg/tcpip/transport/packet/endpoint.go @@ -67,7 +67,7 @@ type endpoint struct { waiterQueue *waiter.Queue cooked bool ops tcpip.SocketOptions - stats tcpip.TransportEndpointStats `state:"nosave"` + stats tcpip.TransportEndpointStats // The following fields are used to manage the receive queue. rcvMu sync.Mutex `state:"nosave"` diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go index 3040a445b..bfef75da7 100644 --- a/pkg/tcpip/transport/raw/endpoint.go +++ b/pkg/tcpip/transport/raw/endpoint.go @@ -70,7 +70,7 @@ type endpoint struct { associated bool net network.Endpoint - stats tcpip.TransportEndpointStats `state:"nosave"` + stats tcpip.TransportEndpointStats ops tcpip.SocketOptions // The following fields are used to manage the receive queue and are diff --git a/pkg/tcpip/transport/tcp/BUILD b/pkg/tcpip/transport/tcp/BUILD index 5148fe157..20958d882 100644 --- a/pkg/tcpip/transport/tcp/BUILD +++ b/pkg/tcpip/transport/tcp/BUILD @@ -80,9 +80,10 @@ go_library( go_test( name = "tcp_x_test", - size = "medium", + size = "large", srcs = [ "dual_stack_test.go", + "rcv_test.go", "sack_scoreboard_test.go", "tcp_noracedetector_test.go", "tcp_rack_test.go", @@ -114,16 +115,6 @@ go_test( ) go_test( - name = "rcv_test", - size = "small", - srcs = ["rcv_test.go"], - deps = [ - "//pkg/tcpip/header", - "//pkg/tcpip/seqnum", - ], -) - -go_test( name = "tcp_test", size = "small", srcs = [ diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go index ff0a5df9c..7115d0a12 100644 --- a/pkg/tcpip/transport/tcp/accept.go +++ b/pkg/tcpip/transport/tcp/accept.go @@ -193,14 +193,6 @@ func (l *listenContext) isCookieValid(id stack.TransportEndpointID, cookie seqnu return (v - l.cookieHash(id, cookieTS, 1)) & hashMask, true } -func (l *listenContext) useSynCookies() bool { - var alwaysUseSynCookies tcpip.TCPAlwaysUseSynCookies - if err := l.stack.TransportProtocolOption(header.TCPProtocolNumber, &alwaysUseSynCookies); err != nil { - panic(fmt.Sprintf("TransportProtocolOption(%d, %T) = %s", header.TCPProtocolNumber, alwaysUseSynCookies, err)) - } - return bool(alwaysUseSynCookies) || (l.listenEP != nil && l.listenEP.synRcvdBacklogFull()) -} - // createConnectingEndpoint creates a new endpoint in a connecting state, with // the connection parameters given by the arguments. func (l *listenContext) createConnectingEndpoint(s *segment, rcvdSynOpts header.TCPSynOptions, queue *waiter.Queue) (*endpoint, tcpip.Error) { @@ -277,7 +269,7 @@ func (l *listenContext) startHandshake(s *segment, opts header.TCPSynOptions, qu // Propagate any inheritable options from the listening endpoint // to the newly created endpoint. - l.listenEP.propagateInheritableOptionsLocked(ep) + l.listenEP.propagateInheritableOptionsLocked(ep) // +checklocksforce if !ep.reserveTupleLocked() { ep.mu.Unlock() @@ -367,7 +359,6 @@ func (l *listenContext) closeAllPendingEndpoints() { l.pending.Wait() } -// Precondition: h.ep.mu must be held. // +checklocks:h.ep.mu func (l *listenContext) cleanupFailedHandshake(h *handshake) { e := h.ep @@ -384,7 +375,7 @@ func (l *listenContext) cleanupFailedHandshake(h *handshake) { // cleanupCompletedHandshake transfers any state from the completed handshake to // the new endpoint. // -// Precondition: h.ep.mu must be held. +// +checklocks:h.ep.mu func (l *listenContext) cleanupCompletedHandshake(h *handshake) { e := h.ep if l.listenEP != nil { @@ -404,7 +395,8 @@ func (l *listenContext) cleanupCompletedHandshake(h *handshake) { // propagateInheritableOptionsLocked propagates any options set on the listening // endpoint to the newly created endpoint. // -// Precondition: e.mu and n.mu must be held. +// +checklocks:e.mu +// +checklocks:n.mu func (e *endpoint) propagateInheritableOptionsLocked(n *endpoint) { n.userTimeout = e.userTimeout n.portFlags = e.portFlags @@ -415,9 +407,9 @@ func (e *endpoint) propagateInheritableOptionsLocked(n *endpoint) { // reserveTupleLocked reserves an accepted endpoint's tuple. // -// Preconditions: -// * propagateInheritableOptionsLocked has been called. -// * e.mu is held. +// Precondition: e.propagateInheritableOptionsLocked has been called. +// +// +checklocks:e.mu func (e *endpoint) reserveTupleLocked() bool { dest := tcpip.FullAddress{ Addr: e.TransportEndpointInfo.ID.RemoteAddress, @@ -459,7 +451,7 @@ func (e *endpoint) notifyAborted() { // A limited number of these goroutines are allowed before TCP starts using SYN // cookies to accept connections. // -// Precondition: if ctx.listenEP != nil, ctx.listenEP.mu must be locked. +// +checklocks:e.mu func (e *endpoint) handleSynSegment(ctx *listenContext, s *segment, opts header.TCPSynOptions) tcpip.Error { defer s.decRef() @@ -552,7 +544,7 @@ func (a *accepted) acceptQueueIsFullLocked() bool { // handleListenSegment is called when a listening endpoint receives a segment // and needs to handle it. // -// Precondition: if ctx.listenEP != nil, ctx.listenEP.mu must be locked. +// +checklocks:e.mu func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) tcpip.Error { e.rcvQueueInfo.rcvQueueMu.Lock() rcvClosed := e.rcvQueueInfo.RcvClosed @@ -579,8 +571,16 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) tcpip.Err return nil } + alwaysUseSynCookies := func() bool { + var alwaysUseSynCookies tcpip.TCPAlwaysUseSynCookies + if err := e.stack.TransportProtocolOption(header.TCPProtocolNumber, &alwaysUseSynCookies); err != nil { + panic(fmt.Sprintf("TransportProtocolOption(%d, %T) = %s", header.TCPProtocolNumber, alwaysUseSynCookies, err)) + } + return bool(alwaysUseSynCookies) + }() + opts := parseSynSegmentOptions(s) - if !ctx.useSynCookies() { + if !alwaysUseSynCookies && !e.synRcvdBacklogFull() { s.incRef() atomic.AddInt32(&e.synRcvdCount, 1) return e.handleSynSegment(ctx, s, opts) diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index a3002abf3..407ab2664 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -2066,7 +2066,7 @@ func (e *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) tcpip.Error { case *tcpip.OriginalDestinationOption: e.LockUser() ipt := e.stack.IPTables() - addr, port, err := ipt.OriginalDst(e.TransportEndpointInfo.ID, e.NetProto) + addr, port, err := ipt.OriginalDst(e.TransportEndpointInfo.ID, e.NetProto, ProtocolNumber) e.UnlockUser() if err != nil { return err diff --git a/pkg/tcpip/transport/tcp/rcv_test.go b/pkg/tcpip/transport/tcp/rcv_test.go index 8a026ec46..e47a07030 100644 --- a/pkg/tcpip/transport/tcp/rcv_test.go +++ b/pkg/tcpip/transport/tcp/rcv_test.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package rcv_test +package tcp_test import ( "testing" diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index b355fa7eb..049957b81 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -60,9 +60,8 @@ type endpoint struct { waiterQueue *waiter.Queue uniqueID uint64 net network.Endpoint - // TODO(b/142022063): Add ability to save and restore per endpoint stats. - stats tcpip.TransportEndpointStats `state:"nosave"` - ops tcpip.SocketOptions + stats tcpip.TransportEndpointStats + ops tcpip.SocketOptions // The following fields are used to manage the receive queue, and are // protected by rcvMu. diff --git a/test/syscalls/linux/socket_ipv4_datagram_based_socket_unbound_loopback.cc b/test/syscalls/linux/socket_ipv4_datagram_based_socket_unbound_loopback.cc index 9ae68e015..4ab4ef261 100644 --- a/test/syscalls/linux/socket_ipv4_datagram_based_socket_unbound_loopback.cc +++ b/test/syscalls/linux/socket_ipv4_datagram_based_socket_unbound_loopback.cc @@ -18,11 +18,11 @@ namespace gvisor { namespace testing { -INSTANTIATE_TEST_SUITE_P(IPv4Sockets, IPv4DatagramBasedUnboundSocketTest, - ::testing::ValuesIn(ApplyVecToVec<SocketKind>( - {IPv4UDPUnboundSocket, IPv4RawUDPUnboundSocket}, - AllBitwiseCombinations(List<int>{ - 0, SOCK_NONBLOCK})))); +INSTANTIATE_TEST_SUITE_P( + IPv4Sockets, IPv4DatagramBasedUnboundSocketTest, + ::testing::ValuesIn(ApplyVecToVec<SocketKind>( + {IPv4UDPUnboundSocket, IPv4RawUDPUnboundSocket, ICMPUnboundSocket}, + AllBitwiseCombinations(List<int>{0, SOCK_NONBLOCK})))); } // namespace testing } // namespace gvisor diff --git a/tools/bigquery/bigquery.go b/tools/bigquery/bigquery.go index 082410697..5aa1fe5dc 100644 --- a/tools/bigquery/bigquery.go +++ b/tools/bigquery/bigquery.go @@ -40,29 +40,45 @@ type Suite struct { } func (s *Suite) String() string { - conditions := make([]string, 0, len(s.Conditions)) - for _, c := range s.Conditions { - conditions = append(conditions, c.String()) - } - benchmarks := make([]string, 0, len(s.Benchmarks)) - for _, b := range s.Benchmarks { - benchmarks = append(benchmarks, b.String()) - } + var sb strings.Builder + s.debugString(&sb, "") + return sb.String() +} - format := `Suite: -Name: %s -Conditions: %s -Benchmarks: %s -Official: %t -Timestamp: %s -` +// writeLine writes a line of text to the given string builder with a prefix. +func writeLine(sb *strings.Builder, prefix string, format string, values ...interface{}) { + if prefix != "" { + sb.WriteString(prefix) + } + sb.WriteString(fmt.Sprintf(format, values...)) + sb.WriteString("\n") +} - return fmt.Sprintf(format, - s.Name, - strings.Join(conditions, "\n"), - strings.Join(benchmarks, "\n"), - s.Official, - s.Timestamp) +// debugString writes debug information to the given string builder with the +// given prefix. +func (s *Suite) debugString(sb *strings.Builder, prefix string) { + writeLine(sb, prefix, "Benchmark suite %s:", s.Name) + writeLine(sb, prefix, "Timestamp: %v", s.Timestamp) + if !s.Official { + writeLine(sb, prefix, " **** NOTE: Data is not official. **** ") + } + if numConditions := len(s.Conditions); numConditions == 0 { + writeLine(sb, prefix, "Conditions: None.") + } else { + writeLine(sb, prefix, "Conditions (%d):", numConditions) + for _, condition := range s.Conditions { + condition.debugString(sb, prefix+" ") + } + } + if numBenchmarks := len(s.Benchmarks); numBenchmarks == 0 { + writeLine(sb, prefix, "Benchmarks: None.") + } else { + writeLine(sb, prefix, "Benchmarks (%d):", numBenchmarks) + for _, benchmark := range s.Benchmarks { + benchmark.debugString(sb, prefix+" ") + } + } + sb.WriteString(fmt.Sprintf("End of data for benchmark suite %s.", s.Name)) } // Benchmark represents an individual benchmark in a suite. @@ -74,25 +90,31 @@ type Benchmark struct { // String implements the String method for Benchmark func (bm *Benchmark) String() string { - conditions := make([]string, 0, len(bm.Condition)) - for _, c := range bm.Condition { - conditions = append(conditions, c.String()) + var sb strings.Builder + bm.debugString(&sb, "") + return sb.String() +} + +// debugString writes debug information to the given string builder with the +// given prefix. +func (bm *Benchmark) debugString(sb *strings.Builder, prefix string) { + writeLine(sb, prefix, "Benchmark: %s", bm.Name) + if numConditions := len(bm.Condition); numConditions == 0 { + writeLine(sb, prefix, " Conditions: None.") + } else { + writeLine(sb, prefix, " Conditions (%d):", numConditions) + for _, condition := range bm.Condition { + condition.debugString(sb, prefix+" ") + } } - metrics := make([]string, 0, len(bm.Metric)) - for _, m := range bm.Metric { - metrics = append(metrics, m.String()) + if numMetrics := len(bm.Metric); numMetrics == 0 { + writeLine(sb, prefix, " Metrics: None.") + } else { + writeLine(sb, prefix, " Metrics (%d):", numMetrics) + for _, metric := range bm.Metric { + metric.debugString(sb, prefix+" ") + } } - - format := `Condition: -Name: %s -Conditions: %s -Metrics: %s -` - - return fmt.Sprintf(format, - bm.Name, - strings.Join(conditions, "\n"), - strings.Join(metrics, "\n")) } // AddMetric adds a metric to an existing Benchmark. @@ -107,10 +129,7 @@ func (bm *Benchmark) AddMetric(metricName, unit string, sample float64) { // AddCondition adds a condition to an existing Benchmark. func (bm *Benchmark) AddCondition(name, value string) { - bm.Condition = append(bm.Condition, &Condition{ - Name: name, - Value: value, - }) + bm.Condition = append(bm.Condition, NewCondition(name, value)) } // NewBenchmark initializes a new benchmark. @@ -136,8 +155,24 @@ type Condition struct { Value string `bq:"value"` } +// NewCondition returns a new Condition with the given name and value. +func NewCondition(name, value string) *Condition { + return &Condition{ + Name: name, + Value: value, + } +} + func (c *Condition) String() string { - return fmt.Sprintf("Condition:\nName: %s Value: %s\n", c.Name, c.Value) + var sb strings.Builder + c.debugString(&sb, "") + return sb.String() +} + +// debugString writes debug information to the given string builder with the +// given prefix. +func (c *Condition) debugString(sb *strings.Builder, prefix string) { + writeLine(sb, prefix, "Condition: %s = %s", c.Name, c.Value) } // Metric holds the actual metric data and unit information for this benchmark. @@ -148,7 +183,15 @@ type Metric struct { } func (m *Metric) String() string { - return fmt.Sprintf("Metric:\nName: %s Unit: %s Sample: %e\n", m.Name, m.Unit, m.Sample) + var sb strings.Builder + m.debugString(&sb, "") + return sb.String() +} + +// debugString writes debug information to the given string builder with the +// given prefix. +func (m *Metric) debugString(sb *strings.Builder, prefix string) { + writeLine(sb, prefix, "Metric %s: %f %s", m.Name, m.Sample, m.Unit) } // InitBigQuery initializes a BigQuery dataset/table in the project. If the dataset/table already exists, it is not duplicated. diff --git a/tools/go_fieldenum/main.go b/tools/go_fieldenum/main.go index 68dfdb3db..d801bea1b 100644 --- a/tools/go_fieldenum/main.go +++ b/tools/go_fieldenum/main.go @@ -55,6 +55,7 @@ func main() { // Determine which types are marked "+fieldenum" and will consequently have // code generated. + var typeNames []string fieldEnumTypes := make(map[string]fieldEnumTypeInfo) for _, f := range inputFiles { for _, decl := range f.Decls { @@ -75,6 +76,7 @@ func main() { if !ok { log.Fatalf("Type %s is marked +fieldenum, but is not a struct", name) } + typeNames = append(typeNames, name) fieldEnumTypes[name] = fieldEnumTypeInfo{ prefix: prefix, structType: st, @@ -86,9 +88,10 @@ func main() { } // Collect information for each type for which code is being generated. - structInfos := make([]structInfo, 0, len(fieldEnumTypes)) + structInfos := make([]structInfo, 0, len(typeNames)) needSyncAtomic := false - for typeName, typeInfo := range fieldEnumTypes { + for _, typeName := range typeNames { + typeInfo := fieldEnumTypes[typeName] var si structInfo si.name = typeName si.prefix = typeInfo.prefix @@ -204,13 +207,6 @@ func structFieldName(f *ast.Field) string { } } -// Workaround for Go defect (map membership test isn't usable in an -// expression). -func fetContains(xs map[string]*ast.StructType, x string) bool { - _, ok := xs[x] - return ok -} - func (si *structInfo) writeTo(b *strings.Builder) { fmt.Fprintf(b, "// A %sField represents a field in %s.\n", si.prefix, si.name) fmt.Fprintf(b, "type %sField uint\n\n", si.prefix) |