diff options
38 files changed, 648 insertions, 540 deletions
@@ -57,6 +57,12 @@ build_test( "//test/e2e:integration_test", "//test/image:image_test", "//test/root:root_test", + "//test/benchmarks/base:base_test", + "//test/benchmarks/database:database_test", + "//test/benchmarks/fs:fs_test", + "//test/benchmarks/media:media_test", + "//test/benchmarks/ml:ml_test", + "//test/benchmarks/network:network_test", ], ) @@ -166,14 +166,11 @@ do-tests: runsc simple-tests: unit-tests # Compatibility target. .PHONY: simple-tests -IMAGE_FILTER := HelloWorld\|Httpd\|Ruby\|Stdio -INTEGRATION_FILTER := Life\|Pause\|Connect\|JobControl\|Overlay\|Exec\|DirCreation/root - docker-tests: load-basic-images @$(call submake,install-test-runtime RUNTIME="vfs1") @$(call submake,test-runtime RUNTIME="vfs1" TARGETS="$(INTEGRATION_TARGETS)") @$(call submake,install-test-runtime RUNTIME="vfs2" ARGS="--vfs2") - @$(call submake,test-runtime RUNTIME="vfs2" OPTIONS="--test_filter=$(IMAGE_FILTER)\|$(INTEGRATION_FILTER)" TARGETS="$(INTEGRATION_TARGETS)") + @$(call submake,test-runtime RUNTIME="vfs2" TARGETS="$(INTEGRATION_TARGETS)") .PHONY: docker-tests overlay-tests: load-basic-images diff --git a/pkg/p9/client_test.go b/pkg/p9/client_test.go index c757583e0..b78fdab7a 100644 --- a/pkg/p9/client_test.go +++ b/pkg/p9/client_test.go @@ -62,6 +62,8 @@ func TestVersion(t *testing.T) { } func benchmarkSendRecv(b *testing.B, fn func(c *Client) func(message, message) error) { + b.ReportAllocs() + // See above. serverSocket, clientSocket, err := unet.SocketPair(false) if err != nil { diff --git a/pkg/p9/transport.go b/pkg/p9/transport.go index 7cec0e86d..02e665345 100644 --- a/pkg/p9/transport.go +++ b/pkg/p9/transport.go @@ -66,14 +66,17 @@ const ( var dataPool = sync.Pool{ New: func() interface{} { // These buffers are used for decoding without a payload. - return make([]byte, initialBufferLength) + // We need to return a pointer to avoid unnecessary allocations + // (see https://staticcheck.io/docs/checks#SA6002). + b := make([]byte, initialBufferLength) + return &b }, } // send sends the given message over the socket. func send(s *unet.Socket, tag Tag, m message) error { - data := dataPool.Get().([]byte) - dataBuf := buffer{data: data[:0]} + data := dataPool.Get().(*[]byte) + dataBuf := buffer{data: (*data)[:0]} if log.IsLogging(log.Debug) { log.Debugf("send [FD %d] [Tag %06d] %s", s.FD(), tag, m.String()) @@ -141,7 +144,7 @@ func send(s *unet.Socket, tag Tag, m message) error { } // All set. - dataPool.Put(dataBuf.data) + dataPool.Put(&dataBuf.data) return nil } @@ -227,12 +230,29 @@ func recv(s *unet.Socket, msize uint32, lookup lookupTagAndType) (Tag, message, // Not yet initialized. var dataBuf buffer + var vecs [][]byte + + appendBuffer := func(size int) *[]byte { + // Pull a data buffer from the pool. + datap := dataPool.Get().(*[]byte) + data := *datap + if size > len(data) { + // Create a larger data buffer. + data = make([]byte, size) + datap = &data + } else { + // Limit the data buffer. + data = data[:size] + } + dataBuf = buffer{data: data} + vecs = append(vecs, data) + return datap + } // Read the rest of the payload. // // This requires some special care to ensure that the vectors all line // up the way they should. We do this to minimize copying data around. - var vecs [][]byte if payloader, ok := m.(payloader); ok { fixedSize := payloader.FixedSize() @@ -246,22 +266,8 @@ func recv(s *unet.Socket, msize uint32, lookup lookupTagAndType) (Tag, message, } if fixedSize != 0 { - // Pull a data buffer from the pool. - data := dataPool.Get().([]byte) - if int(fixedSize) > len(data) { - // Create a larger data buffer, ensuring - // sufficient capicity for the message. - data = make([]byte, fixedSize) - defer dataPool.Put(data) - dataBuf = buffer{data: data} - vecs = append(vecs, data) - } else { - // Limit the data buffer, and make sure it - // gets filled before the payload buffer. - defer dataPool.Put(data) - dataBuf = buffer{data: data[:fixedSize]} - vecs = append(vecs, data[:fixedSize]) - } + datap := appendBuffer(int(fixedSize)) + defer dataPool.Put(datap) } // Include the payload. @@ -274,20 +280,8 @@ func recv(s *unet.Socket, msize uint32, lookup lookupTagAndType) (Tag, message, vecs = append(vecs, p) } } else if remaining != 0 { - // Pull a data buffer from the pool. - data := dataPool.Get().([]byte) - if int(remaining) > len(data) { - // Create a larger data buffer. - data = make([]byte, remaining) - defer dataPool.Put(data) - dataBuf = buffer{data: data} - vecs = append(vecs, data) - } else { - // Limit the data buffer. - defer dataPool.Put(data) - dataBuf = buffer{data: data[:remaining]} - vecs = append(vecs, data[:remaining]) - } + datap := appendBuffer(int(remaining)) + defer dataPool.Put(datap) } if len(vecs) > 0 { diff --git a/pkg/p9/transport_test.go b/pkg/p9/transport_test.go index 3668fcad7..e7406b374 100644 --- a/pkg/p9/transport_test.go +++ b/pkg/p9/transport_test.go @@ -182,6 +182,8 @@ func TestSendClosed(t *testing.T) { } func BenchmarkSendRecv(b *testing.B) { + b.ReportAllocs() + server, client, err := unet.SocketPair(false) if err != nil { b.Fatalf("socketpair got err %v expected nil", err) diff --git a/pkg/refs/refcounter.go b/pkg/refs/refcounter.go index 3f39edb66..d9d5e6bcb 100644 --- a/pkg/refs/refcounter.go +++ b/pkg/refs/refcounter.go @@ -475,3 +475,13 @@ func (r *AtomicRefCount) DecRefWithDestructor(ctx context.Context, destroy func( func (r *AtomicRefCount) DecRef(ctx context.Context) { r.DecRefWithDestructor(ctx, nil) } + +// OnExit is called on sandbox exit. It runs GC to enqueue refcount finalizers, +// which check for reference leaks. There is no way to guarantee that every +// finalizer will run before exiting, but this at least ensures that they will +// be discovered/enqueued by GC. +func OnExit() { + if LeakMode(atomic.LoadUint32(&leakMode)) != NoLeakChecking { + runtime.GC() + } +} diff --git a/pkg/sentry/fsimpl/host/host.go b/pkg/sentry/fsimpl/host/host.go index bf922c566..bd6caba06 100644 --- a/pkg/sentry/fsimpl/host/host.go +++ b/pkg/sentry/fsimpl/host/host.go @@ -552,7 +552,7 @@ func (f *fileDescription) Allocate(ctx context.Context, mode, offset, length uin return syserror.ESPIPE } - // TODO(gvisor.dev/issue/2923): Implement Allocate for non-pipe hostfds. + // TODO(gvisor.dev/issue/3589): Implement Allocate for non-pipe hostfds. return syserror.EOPNOTSUPP } diff --git a/pkg/sentry/socket/netfilter/netfilter.go b/pkg/sentry/socket/netfilter/netfilter.go index a9f0604ae..e91b0624c 100644 --- a/pkg/sentry/socket/netfilter/netfilter.go +++ b/pkg/sentry/socket/netfilter/netfilter.go @@ -32,15 +32,6 @@ import ( "gvisor.dev/gvisor/pkg/usermem" ) -// errorTargetName is used to mark targets as error targets. Error targets -// shouldn't be reached - an error has occurred if we fall through to one. -const errorTargetName = "ERROR" - -// redirectTargetName is used to mark targets as redirect targets. Redirect -// targets should be reached for only NAT and Mangle tables. These targets will -// change the destination port/destination IP for packets. -const redirectTargetName = "REDIRECT" - // enableLogging controls whether to log the (de)serialization of netfilter // structs between userspace and netstack. These logs are useful when // developing iptables, but can pollute sentry logs otherwise. @@ -202,130 +193,6 @@ func convertNetstackToBinary(stack *stack.Stack, tablename linux.TableName) (lin return entries, info, nil } -func marshalTarget(target stack.Target) []byte { - switch tg := target.(type) { - case stack.AcceptTarget: - return marshalStandardTarget(stack.RuleAccept) - case stack.DropTarget: - return marshalStandardTarget(stack.RuleDrop) - case stack.ErrorTarget: - return marshalErrorTarget(errorTargetName) - case stack.UserChainTarget: - return marshalErrorTarget(tg.Name) - case stack.ReturnTarget: - return marshalStandardTarget(stack.RuleReturn) - case stack.RedirectTarget: - return marshalRedirectTarget(tg) - case JumpTarget: - return marshalJumpTarget(tg) - default: - panic(fmt.Errorf("unknown target of type %T", target)) - } -} - -func marshalStandardTarget(verdict stack.RuleVerdict) []byte { - nflog("convert to binary: marshalling standard target") - - // The target's name will be the empty string. - target := linux.XTStandardTarget{ - Target: linux.XTEntryTarget{ - TargetSize: linux.SizeOfXTStandardTarget, - }, - Verdict: translateFromStandardVerdict(verdict), - } - - ret := make([]byte, 0, linux.SizeOfXTStandardTarget) - return binary.Marshal(ret, usermem.ByteOrder, target) -} - -func marshalErrorTarget(errorName string) []byte { - // This is an error target named error - target := linux.XTErrorTarget{ - Target: linux.XTEntryTarget{ - TargetSize: linux.SizeOfXTErrorTarget, - }, - } - copy(target.Name[:], errorName) - copy(target.Target.Name[:], errorTargetName) - - ret := make([]byte, 0, linux.SizeOfXTErrorTarget) - return binary.Marshal(ret, usermem.ByteOrder, target) -} - -func marshalRedirectTarget(rt stack.RedirectTarget) []byte { - // This is a redirect target named redirect - target := linux.XTRedirectTarget{ - Target: linux.XTEntryTarget{ - TargetSize: linux.SizeOfXTRedirectTarget, - }, - } - copy(target.Target.Name[:], redirectTargetName) - - ret := make([]byte, 0, linux.SizeOfXTRedirectTarget) - target.NfRange.RangeSize = 1 - if rt.RangeProtoSpecified { - target.NfRange.RangeIPV4.Flags |= linux.NF_NAT_RANGE_PROTO_SPECIFIED - } - // Convert port from little endian to big endian. - port := make([]byte, 2) - binary.LittleEndian.PutUint16(port, rt.MinPort) - target.NfRange.RangeIPV4.MinPort = binary.BigEndian.Uint16(port) - binary.LittleEndian.PutUint16(port, rt.MaxPort) - target.NfRange.RangeIPV4.MaxPort = binary.BigEndian.Uint16(port) - return binary.Marshal(ret, usermem.ByteOrder, target) -} - -func marshalJumpTarget(jt JumpTarget) []byte { - nflog("convert to binary: marshalling jump target") - - // The target's name will be the empty string. - target := linux.XTStandardTarget{ - Target: linux.XTEntryTarget{ - TargetSize: linux.SizeOfXTStandardTarget, - }, - // Verdict is overloaded by the ABI. When positive, it holds - // the jump offset from the start of the table. - Verdict: int32(jt.Offset), - } - - ret := make([]byte, 0, linux.SizeOfXTStandardTarget) - return binary.Marshal(ret, usermem.ByteOrder, target) -} - -// translateFromStandardVerdict translates verdicts the same way as the iptables -// tool. -func translateFromStandardVerdict(verdict stack.RuleVerdict) int32 { - switch verdict { - case stack.RuleAccept: - return -linux.NF_ACCEPT - 1 - case stack.RuleDrop: - return -linux.NF_DROP - 1 - case stack.RuleReturn: - return linux.NF_RETURN - default: - // TODO(gvisor.dev/issue/170): Support Jump. - panic(fmt.Sprintf("unknown standard verdict: %d", verdict)) - } -} - -// translateToStandardTarget translates from the value in a -// linux.XTStandardTarget to an stack.Verdict. -func translateToStandardTarget(val int32) (stack.Target, error) { - // TODO(gvisor.dev/issue/170): Support other verdicts. - switch val { - case -linux.NF_ACCEPT - 1: - return stack.AcceptTarget{}, nil - case -linux.NF_DROP - 1: - return stack.DropTarget{}, nil - case -linux.NF_QUEUE - 1: - return nil, errors.New("unsupported iptables verdict QUEUE") - case linux.NF_RETURN: - return stack.ReturnTarget{}, nil - default: - return nil, fmt.Errorf("unknown iptables verdict %d", val) - } -} - // SetEntries sets iptables rules for a single table. See // net/ipv4/netfilter/ip_tables.c:translate_table for reference. func SetEntries(stk *stack.Stack, optVal []byte) *syserr.Error { @@ -562,113 +429,6 @@ func parseMatchers(filter stack.IPHeaderFilter, optVal []byte) ([]stack.Matcher, return matchers, nil } -// parseTarget parses a target from optVal. optVal should contain only the -// target. -func parseTarget(filter stack.IPHeaderFilter, optVal []byte) (stack.Target, error) { - nflog("set entries: parsing target of size %d", len(optVal)) - if len(optVal) < linux.SizeOfXTEntryTarget { - return nil, fmt.Errorf("optVal has insufficient size for entry target %d", len(optVal)) - } - var target linux.XTEntryTarget - buf := optVal[:linux.SizeOfXTEntryTarget] - binary.Unmarshal(buf, usermem.ByteOrder, &target) - switch target.Name.String() { - case "": - // Standard target. - if len(optVal) != linux.SizeOfXTStandardTarget { - return nil, fmt.Errorf("optVal has wrong size for standard target %d", len(optVal)) - } - var standardTarget linux.XTStandardTarget - buf = optVal[:linux.SizeOfXTStandardTarget] - binary.Unmarshal(buf, usermem.ByteOrder, &standardTarget) - - if standardTarget.Verdict < 0 { - // A Verdict < 0 indicates a non-jump verdict. - return translateToStandardTarget(standardTarget.Verdict) - } - // A verdict >= 0 indicates a jump. - return JumpTarget{Offset: uint32(standardTarget.Verdict)}, nil - - case errorTargetName: - // Error target. - if len(optVal) != linux.SizeOfXTErrorTarget { - return nil, fmt.Errorf("optVal has insufficient size for error target %d", len(optVal)) - } - var errorTarget linux.XTErrorTarget - buf = optVal[:linux.SizeOfXTErrorTarget] - binary.Unmarshal(buf, usermem.ByteOrder, &errorTarget) - - // Error targets are used in 2 cases: - // * An actual error case. These rules have an error - // named errorTargetName. The last entry of the table - // is usually an error case to catch any packets that - // somehow fall through every rule. - // * To mark the start of a user defined chain. These - // rules have an error with the name of the chain. - switch name := errorTarget.Name.String(); name { - case errorTargetName: - nflog("set entries: error target") - return stack.ErrorTarget{}, nil - default: - // User defined chain. - nflog("set entries: user-defined target %q", name) - return stack.UserChainTarget{Name: name}, nil - } - - case redirectTargetName: - // Redirect target. - if len(optVal) < linux.SizeOfXTRedirectTarget { - return nil, fmt.Errorf("netfilter.SetEntries: optVal has insufficient size for redirect target %d", len(optVal)) - } - - if filter.Protocol != header.TCPProtocolNumber && filter.Protocol != header.UDPProtocolNumber { - return nil, fmt.Errorf("netfilter.SetEntries: invalid argument") - } - - var redirectTarget linux.XTRedirectTarget - buf = optVal[:linux.SizeOfXTRedirectTarget] - binary.Unmarshal(buf, usermem.ByteOrder, &redirectTarget) - - // Copy linux.XTRedirectTarget to stack.RedirectTarget. - var target stack.RedirectTarget - nfRange := redirectTarget.NfRange - - // RangeSize should be 1. - if nfRange.RangeSize != 1 { - return nil, fmt.Errorf("netfilter.SetEntries: invalid argument") - } - - // TODO(gvisor.dev/issue/170): Check if the flags are valid. - // Also check if we need to map ports or IP. - // For now, redirect target only supports destination port change. - // Port range and IP range are not supported yet. - if nfRange.RangeIPV4.Flags&linux.NF_NAT_RANGE_PROTO_SPECIFIED == 0 { - return nil, fmt.Errorf("netfilter.SetEntries: invalid argument") - } - target.RangeProtoSpecified = true - - target.MinIP = tcpip.Address(nfRange.RangeIPV4.MinIP[:]) - target.MaxIP = tcpip.Address(nfRange.RangeIPV4.MaxIP[:]) - - // TODO(gvisor.dev/issue/170): Port range is not supported yet. - if nfRange.RangeIPV4.MinPort != nfRange.RangeIPV4.MaxPort { - return nil, fmt.Errorf("netfilter.SetEntries: invalid argument") - } - - // Convert port from big endian to little endian. - port := make([]byte, 2) - binary.BigEndian.PutUint16(port, nfRange.RangeIPV4.MinPort) - target.MinPort = binary.LittleEndian.Uint16(port) - - binary.BigEndian.PutUint16(port, nfRange.RangeIPV4.MaxPort) - target.MaxPort = binary.LittleEndian.Uint16(port) - return target, nil - } - - // Unknown target. - return nil, fmt.Errorf("unknown target %q doesn't exist or isn't supported yet.", target.Name.String()) -} - func filterFromIPTIP(iptip linux.IPTIP) (stack.IPHeaderFilter, error) { if containsUnsupportedFields(iptip) { return stack.IPHeaderFilter{}, fmt.Errorf("unsupported fields in struct iptip: %+v", iptip) diff --git a/pkg/sentry/socket/netfilter/targets.go b/pkg/sentry/socket/netfilter/targets.go index b91ba3ab3..8ebdaff18 100644 --- a/pkg/sentry/socket/netfilter/targets.go +++ b/pkg/sentry/socket/netfilter/targets.go @@ -15,10 +15,257 @@ package netfilter import ( + "errors" + "fmt" + + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/binary" "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/stack" + "gvisor.dev/gvisor/pkg/usermem" ) +// errorTargetName is used to mark targets as error targets. Error targets +// shouldn't be reached - an error has occurred if we fall through to one. +const errorTargetName = "ERROR" + +// redirectTargetName is used to mark targets as redirect targets. Redirect +// targets should be reached for only NAT and Mangle tables. These targets will +// change the destination port/destination IP for packets. +const redirectTargetName = "REDIRECT" + +func marshalTarget(target stack.Target) []byte { + switch tg := target.(type) { + case stack.AcceptTarget: + return marshalStandardTarget(stack.RuleAccept) + case stack.DropTarget: + return marshalStandardTarget(stack.RuleDrop) + case stack.ErrorTarget: + return marshalErrorTarget(errorTargetName) + case stack.UserChainTarget: + return marshalErrorTarget(tg.Name) + case stack.ReturnTarget: + return marshalStandardTarget(stack.RuleReturn) + case stack.RedirectTarget: + return marshalRedirectTarget(tg) + case JumpTarget: + return marshalJumpTarget(tg) + default: + panic(fmt.Errorf("unknown target of type %T", target)) + } +} + +func marshalStandardTarget(verdict stack.RuleVerdict) []byte { + nflog("convert to binary: marshalling standard target") + + // The target's name will be the empty string. + target := linux.XTStandardTarget{ + Target: linux.XTEntryTarget{ + TargetSize: linux.SizeOfXTStandardTarget, + }, + Verdict: translateFromStandardVerdict(verdict), + } + + ret := make([]byte, 0, linux.SizeOfXTStandardTarget) + return binary.Marshal(ret, usermem.ByteOrder, target) +} + +func marshalErrorTarget(errorName string) []byte { + // This is an error target named error + target := linux.XTErrorTarget{ + Target: linux.XTEntryTarget{ + TargetSize: linux.SizeOfXTErrorTarget, + }, + } + copy(target.Name[:], errorName) + copy(target.Target.Name[:], errorTargetName) + + ret := make([]byte, 0, linux.SizeOfXTErrorTarget) + return binary.Marshal(ret, usermem.ByteOrder, target) +} + +func marshalRedirectTarget(rt stack.RedirectTarget) []byte { + // This is a redirect target named redirect + target := linux.XTRedirectTarget{ + Target: linux.XTEntryTarget{ + TargetSize: linux.SizeOfXTRedirectTarget, + }, + } + copy(target.Target.Name[:], redirectTargetName) + + ret := make([]byte, 0, linux.SizeOfXTRedirectTarget) + target.NfRange.RangeSize = 1 + if rt.RangeProtoSpecified { + target.NfRange.RangeIPV4.Flags |= linux.NF_NAT_RANGE_PROTO_SPECIFIED + } + // Convert port from little endian to big endian. + port := make([]byte, 2) + binary.LittleEndian.PutUint16(port, rt.MinPort) + target.NfRange.RangeIPV4.MinPort = binary.BigEndian.Uint16(port) + binary.LittleEndian.PutUint16(port, rt.MaxPort) + target.NfRange.RangeIPV4.MaxPort = binary.BigEndian.Uint16(port) + return binary.Marshal(ret, usermem.ByteOrder, target) +} + +func marshalJumpTarget(jt JumpTarget) []byte { + nflog("convert to binary: marshalling jump target") + + // The target's name will be the empty string. + target := linux.XTStandardTarget{ + Target: linux.XTEntryTarget{ + TargetSize: linux.SizeOfXTStandardTarget, + }, + // Verdict is overloaded by the ABI. When positive, it holds + // the jump offset from the start of the table. + Verdict: int32(jt.Offset), + } + + ret := make([]byte, 0, linux.SizeOfXTStandardTarget) + return binary.Marshal(ret, usermem.ByteOrder, target) +} + +// translateFromStandardVerdict translates verdicts the same way as the iptables +// tool. +func translateFromStandardVerdict(verdict stack.RuleVerdict) int32 { + switch verdict { + case stack.RuleAccept: + return -linux.NF_ACCEPT - 1 + case stack.RuleDrop: + return -linux.NF_DROP - 1 + case stack.RuleReturn: + return linux.NF_RETURN + default: + // TODO(gvisor.dev/issue/170): Support Jump. + panic(fmt.Sprintf("unknown standard verdict: %d", verdict)) + } +} + +// translateToStandardTarget translates from the value in a +// linux.XTStandardTarget to an stack.Verdict. +func translateToStandardTarget(val int32) (stack.Target, error) { + // TODO(gvisor.dev/issue/170): Support other verdicts. + switch val { + case -linux.NF_ACCEPT - 1: + return stack.AcceptTarget{}, nil + case -linux.NF_DROP - 1: + return stack.DropTarget{}, nil + case -linux.NF_QUEUE - 1: + return nil, errors.New("unsupported iptables verdict QUEUE") + case linux.NF_RETURN: + return stack.ReturnTarget{}, nil + default: + return nil, fmt.Errorf("unknown iptables verdict %d", val) + } +} + +// parseTarget parses a target from optVal. optVal should contain only the +// target. +func parseTarget(filter stack.IPHeaderFilter, optVal []byte) (stack.Target, error) { + nflog("set entries: parsing target of size %d", len(optVal)) + if len(optVal) < linux.SizeOfXTEntryTarget { + return nil, fmt.Errorf("optVal has insufficient size for entry target %d", len(optVal)) + } + var target linux.XTEntryTarget + buf := optVal[:linux.SizeOfXTEntryTarget] + binary.Unmarshal(buf, usermem.ByteOrder, &target) + switch target.Name.String() { + case "": + // Standard target. + if len(optVal) != linux.SizeOfXTStandardTarget { + return nil, fmt.Errorf("optVal has wrong size for standard target %d", len(optVal)) + } + var standardTarget linux.XTStandardTarget + buf = optVal[:linux.SizeOfXTStandardTarget] + binary.Unmarshal(buf, usermem.ByteOrder, &standardTarget) + + if standardTarget.Verdict < 0 { + // A Verdict < 0 indicates a non-jump verdict. + return translateToStandardTarget(standardTarget.Verdict) + } + // A verdict >= 0 indicates a jump. + return JumpTarget{Offset: uint32(standardTarget.Verdict)}, nil + + case errorTargetName: + // Error target. + if len(optVal) != linux.SizeOfXTErrorTarget { + return nil, fmt.Errorf("optVal has insufficient size for error target %d", len(optVal)) + } + var errorTarget linux.XTErrorTarget + buf = optVal[:linux.SizeOfXTErrorTarget] + binary.Unmarshal(buf, usermem.ByteOrder, &errorTarget) + + // Error targets are used in 2 cases: + // * An actual error case. These rules have an error + // named errorTargetName. The last entry of the table + // is usually an error case to catch any packets that + // somehow fall through every rule. + // * To mark the start of a user defined chain. These + // rules have an error with the name of the chain. + switch name := errorTarget.Name.String(); name { + case errorTargetName: + nflog("set entries: error target") + return stack.ErrorTarget{}, nil + default: + // User defined chain. + nflog("set entries: user-defined target %q", name) + return stack.UserChainTarget{Name: name}, nil + } + + case redirectTargetName: + // Redirect target. + if len(optVal) < linux.SizeOfXTRedirectTarget { + return nil, fmt.Errorf("netfilter.SetEntries: optVal has insufficient size for redirect target %d", len(optVal)) + } + + if filter.Protocol != header.TCPProtocolNumber && filter.Protocol != header.UDPProtocolNumber { + return nil, fmt.Errorf("netfilter.SetEntries: invalid argument") + } + + var redirectTarget linux.XTRedirectTarget + buf = optVal[:linux.SizeOfXTRedirectTarget] + binary.Unmarshal(buf, usermem.ByteOrder, &redirectTarget) + + // Copy linux.XTRedirectTarget to stack.RedirectTarget. + var target stack.RedirectTarget + nfRange := redirectTarget.NfRange + + // RangeSize should be 1. + if nfRange.RangeSize != 1 { + return nil, fmt.Errorf("netfilter.SetEntries: invalid argument") + } + + // TODO(gvisor.dev/issue/170): Check if the flags are valid. + // Also check if we need to map ports or IP. + // For now, redirect target only supports destination port change. + // Port range and IP range are not supported yet. + if nfRange.RangeIPV4.Flags&linux.NF_NAT_RANGE_PROTO_SPECIFIED == 0 { + return nil, fmt.Errorf("netfilter.SetEntries: invalid argument") + } + target.RangeProtoSpecified = true + + target.MinIP = tcpip.Address(nfRange.RangeIPV4.MinIP[:]) + target.MaxIP = tcpip.Address(nfRange.RangeIPV4.MaxIP[:]) + + // TODO(gvisor.dev/issue/170): Port range is not supported yet. + if nfRange.RangeIPV4.MinPort != nfRange.RangeIPV4.MaxPort { + return nil, fmt.Errorf("netfilter.SetEntries: invalid argument") + } + + // Convert port from big endian to little endian. + port := make([]byte, 2) + binary.BigEndian.PutUint16(port, nfRange.RangeIPV4.MinPort) + target.MinPort = binary.LittleEndian.Uint16(port) + + binary.BigEndian.PutUint16(port, nfRange.RangeIPV4.MaxPort) + target.MaxPort = binary.LittleEndian.Uint16(port) + return target, nil + } + + // Unknown target. + return nil, fmt.Errorf("unknown target %q doesn't exist or isn't supported yet", target.Name.String()) +} + // JumpTarget implements stack.Target. type JumpTarget struct { // Offset is the byte offset of the rule to jump to. It is used for diff --git a/pkg/sentry/syscalls/linux/vfs2/aio.go b/pkg/sentry/syscalls/linux/vfs2/aio.go index 399b4f60c..42559bf69 100644 --- a/pkg/sentry/syscalls/linux/vfs2/aio.go +++ b/pkg/sentry/syscalls/linux/vfs2/aio.go @@ -144,6 +144,12 @@ func submitCallback(t *kernel.Task, id uint64, cb *linux.IOCallback, cbAddr user func getAIOCallback(t *kernel.Task, fd, eventFD *vfs.FileDescription, cbAddr usermem.Addr, cb *linux.IOCallback, ioseq usermem.IOSequence, aioCtx *mm.AIOContext) kernel.AIOCallback { return func(ctx context.Context) { + // Release references after completing the callback. + defer fd.DecRef(ctx) + if eventFD != nil { + defer eventFD.DecRef(ctx) + } + if aioCtx.Dead() { aioCtx.CancelPendingRequest() return @@ -169,8 +175,6 @@ func getAIOCallback(t *kernel.Task, fd, eventFD *vfs.FileDescription, cbAddr use ev.Result = -int64(kernel.ExtractErrno(err, 0)) } - fd.DecRef(ctx) - // Queue the result for delivery. aioCtx.FinishRequest(ev) @@ -179,7 +183,6 @@ func getAIOCallback(t *kernel.Task, fd, eventFD *vfs.FileDescription, cbAddr use // wake up. if eventFD != nil { eventFD.Impl().(*eventfd.EventFileDescription).Signal(1) - eventFD.DecRef(ctx) } } } diff --git a/pkg/sentry/vfs/file_description.go b/pkg/sentry/vfs/file_description.go index d3c1197e3..dcafffe57 100644 --- a/pkg/sentry/vfs/file_description.go +++ b/pkg/sentry/vfs/file_description.go @@ -289,7 +289,7 @@ func (fd *FileDescription) SetStatusFlags(ctx context.Context, creds *auth.Crede if flags&linux.O_DIRECT != 0 && !fd.opts.AllowDirectIO { return syserror.EINVAL } - // TODO(jamieliu): FileDescriptionImpl.SetOAsync()? + // TODO(gvisor.dev/issue/1035): FileDescriptionImpl.SetOAsync()? const settableFlags = linux.O_APPEND | linux.O_ASYNC | linux.O_DIRECT | linux.O_NOATIME | linux.O_NONBLOCK fd.flagsMu.Lock() if fd.asyncHandler != nil { @@ -301,7 +301,7 @@ func (fd *FileDescription) SetStatusFlags(ctx context.Context, creds *auth.Crede fd.asyncHandler.Unregister(fd) } } - fd.statusFlags = (oldFlags &^ settableFlags) | (flags & settableFlags) + atomic.StoreUint32(&fd.statusFlags, (oldFlags&^settableFlags)|(flags&settableFlags)) fd.flagsMu.Unlock() return nil } diff --git a/pkg/tcpip/checker/BUILD b/pkg/tcpip/checker/BUILD index ed434807f..c984470e6 100644 --- a/pkg/tcpip/checker/BUILD +++ b/pkg/tcpip/checker/BUILD @@ -12,5 +12,6 @@ go_library( "//pkg/tcpip/buffer", "//pkg/tcpip/header", "//pkg/tcpip/seqnum", + "@com_github_google_go_cmp//cmp:go_default_library", ], ) diff --git a/pkg/tcpip/checker/checker.go b/pkg/tcpip/checker/checker.go index ee264b726..1e5f5abf2 100644 --- a/pkg/tcpip/checker/checker.go +++ b/pkg/tcpip/checker/checker.go @@ -21,6 +21,7 @@ import ( "reflect" "testing" + "github.com/google/go-cmp/cmp" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" @@ -169,10 +170,9 @@ func ReceiveTClass(want uint32) ControlMessagesChecker { return func(t *testing.T, cm tcpip.ControlMessages) { t.Helper() if !cm.HasTClass { - t.Fatalf("got cm.HasTClass = %t, want cm.TClass = %d", cm.HasTClass, want) - } - if got := cm.TClass; got != want { - t.Fatalf("got cm.TClass = %d, want %d", got, want) + t.Errorf("got cm.HasTClass = %t, want = true", cm.HasTClass) + } else if got := cm.TClass; got != want { + t.Errorf("got cm.TClass = %d, want %d", got, want) } } } @@ -182,10 +182,22 @@ func ReceiveTOS(want uint8) ControlMessagesChecker { return func(t *testing.T, cm tcpip.ControlMessages) { t.Helper() if !cm.HasTOS { - t.Fatalf("got cm.HasTOS = %t, want cm.TOS = %d", cm.HasTOS, want) + t.Errorf("got cm.HasTOS = %t, want = true", cm.HasTOS) + } else if got := cm.TOS; got != want { + t.Errorf("got cm.TOS = %d, want %d", got, want) } - if got := cm.TOS; got != want { - t.Fatalf("got cm.TOS = %d, want %d", got, want) + } +} + +// ReceiveIPPacketInfo creates a checker that checks the PacketInfo field in +// ControlMessages. +func ReceiveIPPacketInfo(want tcpip.IPPacketInfo) ControlMessagesChecker { + return func(t *testing.T, cm tcpip.ControlMessages) { + t.Helper() + if !cm.HasIPPacketInfo { + t.Errorf("got cm.HasIPPacketInfo = %t, want = true", cm.HasIPPacketInfo) + } else if diff := cmp.Diff(want, cm.PacketInfo); diff != "" { + t.Errorf("IPPacketInfo mismatch (-want +got):\n%s", diff) } } } diff --git a/pkg/tcpip/link/qdisc/fifo/endpoint.go b/pkg/tcpip/link/qdisc/fifo/endpoint.go index 467083239..fc1e34fc7 100644 --- a/pkg/tcpip/link/qdisc/fifo/endpoint.go +++ b/pkg/tcpip/link/qdisc/fifo/endpoint.go @@ -199,7 +199,7 @@ func (e *endpoint) WritePackets(_ *stack.Route, _ *stack.GSO, pkts stack.PacketB // WriteRawPacket implements stack.LinkEndpoint.WriteRawPacket. func (e *endpoint) WriteRawPacket(vv buffer.VectorisedView) *tcpip.Error { - // TODO(gvisor.dev/issue/3267/): Queue these packets as well once + // TODO(gvisor.dev/issue/3267): Queue these packets as well once // WriteRawPacket takes PacketBuffer instead of VectorisedView. return e.lower.WriteRawPacket(vv) } diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go index 6c4f0ae3e..9ff27a363 100644 --- a/pkg/tcpip/network/ipv4/ipv4.go +++ b/pkg/tcpip/network/ipv4/ipv4.go @@ -173,9 +173,10 @@ func (e *endpoint) writePacketFragments(r *stack.Route, gso *stack.GSO, mtu int, newPayload := pkt.Data.Clone(nil) newPayload.CapLength(innerMTU) if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, &stack.PacketBuffer{ - Header: pkt.Header, - Data: newPayload, - NetworkHeader: buffer.View(h), + Header: pkt.Header, + Data: newPayload, + NetworkHeader: buffer.View(h), + NetworkProtocolNumber: header.IPv4ProtocolNumber, }); err != nil { return err } @@ -192,9 +193,10 @@ func (e *endpoint) writePacketFragments(r *stack.Route, gso *stack.GSO, mtu int, newPayloadLength := outerMTU - pkt.Header.UsedLength() newPayload.CapLength(newPayloadLength) if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, &stack.PacketBuffer{ - Header: pkt.Header, - Data: newPayload, - NetworkHeader: buffer.View(h), + Header: pkt.Header, + Data: newPayload, + NetworkHeader: buffer.View(h), + NetworkProtocolNumber: header.IPv4ProtocolNumber, }); err != nil { return err } @@ -206,9 +208,10 @@ func (e *endpoint) writePacketFragments(r *stack.Route, gso *stack.GSO, mtu int, startOfHdr.TrimBack(pkt.Header.UsedLength() - outerMTU) emptyVV := buffer.NewVectorisedView(0, []buffer.View{}) if err := e.linkEP.WritePacket(r, gso, ProtocolNumber, &stack.PacketBuffer{ - Header: startOfHdr, - Data: emptyVV, - NetworkHeader: buffer.View(h), + Header: startOfHdr, + Data: emptyVV, + NetworkHeader: buffer.View(h), + NetworkProtocolNumber: header.IPv4ProtocolNumber, }); err != nil { return err } @@ -249,10 +252,11 @@ func (e *endpoint) addIPHeader(r *stack.Route, hdr *buffer.Prependable, payloadS func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, pkt *stack.PacketBuffer) *tcpip.Error { ip := e.addIPHeader(r, &pkt.Header, pkt.Data.Size(), params) pkt.NetworkHeader = buffer.View(ip) + pkt.NetworkProtocolNumber = header.IPv4ProtocolNumber - nicName := e.stack.FindNICNameFromID(e.NICID()) // iptables filtering. All packets that reach here are locally // generated. + nicName := e.stack.FindNICNameFromID(e.NICID()) ipt := e.stack.IPTables() if ok := ipt.Check(stack.Output, pkt, gso, r, "", nicName); !ok { // iptables is telling us to drop the packet. @@ -304,6 +308,7 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe for pkt := pkts.Front(); pkt != nil; { ip := e.addIPHeader(r, &pkt.Header, pkt.Data.Size(), params) pkt.NetworkHeader = buffer.View(ip) + pkt.NetworkProtocolNumber = header.IPv4ProtocolNumber pkt = pkt.Next() } @@ -570,6 +575,7 @@ func (*protocol) Parse(pkt *stack.PacketBuffer) (proto tcpip.TransportProtocolNu parseTransportHeader = false } + pkt.NetworkProtocolNumber = header.IPv4ProtocolNumber pkt.NetworkHeader = hdr pkt.Data.TrimFront(len(hdr)) pkt.Data.CapLength(int(ipHdr.TotalLength()) - len(hdr)) diff --git a/pkg/tcpip/network/ipv4/ipv4_test.go b/pkg/tcpip/network/ipv4/ipv4_test.go index ded97ac64..63e2c36c2 100644 --- a/pkg/tcpip/network/ipv4/ipv4_test.go +++ b/pkg/tcpip/network/ipv4/ipv4_test.go @@ -150,6 +150,9 @@ func compareFragments(t *testing.T, packets []*stack.PacketBuffer, sourcePacketI if got, want := packet.Header.AvailableLength(), sourcePacketInfo.Header.AvailableLength()-header.IPv4MinimumSize; got != want { t.Errorf("fragment #%d should have the same available space for prepending as source: got %d, want %d", i, got, want) } + if got, want := packet.NetworkProtocolNumber, sourcePacketInfo.NetworkProtocolNumber; got != want { + t.Errorf("fragment #%d has wrong network protocol number: got %d, want %d", i, got, want) + } if i < len(packets)-1 { sourceCopy.SetFlagsFragmentOffset(sourceCopy.Flags()|header.IPv4FlagMoreFragments, offset) } else { @@ -285,7 +288,8 @@ func TestFragmentation(t *testing.T) { source := &stack.PacketBuffer{ Header: hdr, // Save the source payload because WritePacket will modify it. - Data: payload.Clone(nil), + Data: payload.Clone(nil), + NetworkProtocolNumber: header.IPv4ProtocolNumber, } c := buildContext(t, nil, ft.mtu) err := c.Route.WritePacket(ft.gso, stack.NetworkHeaderParams{ diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go index 4a0b53c45..d7d7fc611 100644 --- a/pkg/tcpip/network/ipv6/ipv6.go +++ b/pkg/tcpip/network/ipv6/ipv6.go @@ -117,6 +117,7 @@ func (e *endpoint) addIPHeader(r *stack.Route, hdr *buffer.Prependable, payloadS func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.NetworkHeaderParams, pkt *stack.PacketBuffer) *tcpip.Error { ip := e.addIPHeader(r, &pkt.Header, pkt.Data.Size(), params) pkt.NetworkHeader = buffer.View(ip) + pkt.NetworkProtocolNumber = header.IPv6ProtocolNumber if r.Loop&stack.PacketLoop != 0 { // The inbound path expects the network header to still be in @@ -152,6 +153,7 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe for pb := pkts.Front(); pb != nil; pb = pb.Next() { ip := e.addIPHeader(r, &pb.Header, pb.Data.Size(), params) pb.NetworkHeader = buffer.View(ip) + pb.NetworkProtocolNumber = header.IPv6ProtocolNumber } n, err := e.linkEP.WritePackets(r, gso, pkts, ProtocolNumber) @@ -586,6 +588,7 @@ traverseExtensions: } ipHdr = header.IPv6(hdr) + pkt.NetworkProtocolNumber = header.IPv6ProtocolNumber pkt.NetworkHeader = hdr pkt.Data.TrimFront(len(hdr)) pkt.Data.CapLength(int(ipHdr.PayloadLength())) diff --git a/pkg/tcpip/stack/packet_buffer.go b/pkg/tcpip/stack/packet_buffer.go index 5d6865e35..9e871f968 100644 --- a/pkg/tcpip/stack/packet_buffer.go +++ b/pkg/tcpip/stack/packet_buffer.go @@ -62,6 +62,11 @@ type PacketBuffer struct { NetworkHeader buffer.View TransportHeader buffer.View + // NetworkProtocol is only valid when NetworkHeader is set. + // TODO(gvisor.dev/issue/3574): Remove the separately passed protocol + // numbers in registration APIs that take a PacketBuffer. + NetworkProtocolNumber tcpip.NetworkProtocolNumber + // Hash is the transport layer hash of this packet. A value of zero // indicates no valid hash has been set. Hash uint32 @@ -72,9 +77,8 @@ type PacketBuffer struct { // The following fields are only set by the qdisc layer when the packet // is added to a queue. - EgressRoute *Route - GSOOptions *GSO - NetworkProtocolNumber tcpip.NetworkProtocolNumber + EgressRoute *Route + GSOOptions *GSO // NatDone indicates if the packet has been manipulated as per NAT // iptables rule. diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go index 8604c4259..4570e8969 100644 --- a/pkg/tcpip/stack/registration.go +++ b/pkg/tcpip/stack/registration.go @@ -249,8 +249,8 @@ type NetworkEndpoint interface { MaxHeaderLength() uint16 // WritePacket writes a packet to the given destination address and - // protocol. It takes ownership of pkt. pkt.TransportHeader must have already - // been set. + // protocol. It takes ownership of pkt. pkt.TransportHeader must have + // already been set. WritePacket(r *Route, gso *GSO, params NetworkHeaderParams, pkt *PacketBuffer) *tcpip.Error // WritePackets writes packets to the given destination address and diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go index 45f59b60f..091bc5281 100644 --- a/pkg/tcpip/tcpip.go +++ b/pkg/tcpip/tcpip.go @@ -968,7 +968,7 @@ type IPPacketInfo struct { // LocalAddr is the local address. LocalAddr Address - // DestinationAddr is the destination address. + // DestinationAddr is the destination address found in the IP header. DestinationAddr Address } diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index 444b5b01c..4a2b6c03a 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -1444,13 +1444,16 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pk switch r.NetProto { case header.IPv4ProtocolNumber: packet.tos, _ = header.IPv4(pkt.NetworkHeader).TOS() - packet.packetInfo.LocalAddr = r.LocalAddress - packet.packetInfo.DestinationAddr = r.RemoteAddress - packet.packetInfo.NIC = r.NICID() case header.IPv6ProtocolNumber: packet.tos, _ = header.IPv6(pkt.NetworkHeader).TOS() } + // TODO(gvisor.dev/issue/3556): r.LocalAddress may be a multicast or broadcast + // address. packetInfo.LocalAddr should hold a unicast address that can be + // used to respond to the incoming packet. + packet.packetInfo.LocalAddr = r.LocalAddress + packet.packetInfo.DestinationAddr = r.LocalAddress + packet.packetInfo.NIC = r.NICID() packet.timestamp = e.stack.Clock().NowNanoseconds() e.rcvMu.Unlock() diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go index 66e8911c8..1a32622ca 100644 --- a/pkg/tcpip/transport/udp/udp_test.go +++ b/pkg/tcpip/transport/udp/udp_test.go @@ -1309,6 +1309,105 @@ func TestReadIncrementsPacketsReceived(t *testing.T) { } } +func TestReadIPPacketInfo(t *testing.T) { + tests := []struct { + name string + proto tcpip.NetworkProtocolNumber + flow testFlow + expectedLocalAddr tcpip.Address + expectedDestAddr tcpip.Address + }{ + { + name: "IPv4 unicast", + proto: header.IPv4ProtocolNumber, + flow: unicastV4, + expectedLocalAddr: stackAddr, + expectedDestAddr: stackAddr, + }, + { + name: "IPv4 multicast", + proto: header.IPv4ProtocolNumber, + flow: multicastV4, + // This should actually be a unicast address assigned to the interface. + // + // TODO(gvisor.dev/issue/3556): This check is validating incorrect + // behaviour. We still include the test so that once the bug is + // resolved, this test will start to fail and the individual tasked + // with fixing this bug knows to also fix this test :). + expectedLocalAddr: multicastAddr, + expectedDestAddr: multicastAddr, + }, + { + name: "IPv4 broadcast", + proto: header.IPv4ProtocolNumber, + flow: broadcast, + // This should actually be a unicast address assigned to the interface. + // + // TODO(gvisor.dev/issue/3556): This check is validating incorrect + // behaviour. We still include the test so that once the bug is + // resolved, this test will start to fail and the individual tasked + // with fixing this bug knows to also fix this test :). + expectedLocalAddr: broadcastAddr, + expectedDestAddr: broadcastAddr, + }, + { + name: "IPv6 unicast", + proto: header.IPv6ProtocolNumber, + flow: unicastV6, + expectedLocalAddr: stackV6Addr, + expectedDestAddr: stackV6Addr, + }, + { + name: "IPv6 multicast", + proto: header.IPv6ProtocolNumber, + flow: multicastV6, + // This should actually be a unicast address assigned to the interface. + // + // TODO(gvisor.dev/issue/3556): This check is validating incorrect + // behaviour. We still include the test so that once the bug is + // resolved, this test will start to fail and the individual tasked + // with fixing this bug knows to also fix this test :). + expectedLocalAddr: multicastV6Addr, + expectedDestAddr: multicastV6Addr, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + c := newDualTestContext(t, defaultMTU) + defer c.cleanup() + + c.createEndpoint(test.proto) + + bindAddr := tcpip.FullAddress{Port: stackPort} + if err := c.ep.Bind(bindAddr); err != nil { + t.Fatalf("Bind(%+v): %s", bindAddr, err) + } + + if test.flow.isMulticast() { + ifoptSet := tcpip.AddMembershipOption{NIC: 1, MulticastAddr: test.flow.getMcastAddr()} + if err := c.ep.SetSockOpt(ifoptSet); err != nil { + c.t.Fatalf("SetSockOpt(%+v): %s:", ifoptSet, err) + } + } + + if err := c.ep.SetSockOptBool(tcpip.ReceiveIPPacketInfoOption, true); err != nil { + t.Fatalf("c.ep.SetSockOptBool(tcpip.ReceiveIPPacketInfoOption, true): %s", err) + } + + testRead(c, test.flow, checker.ReceiveIPPacketInfo(tcpip.IPPacketInfo{ + NIC: 1, + LocalAddr: test.expectedLocalAddr, + DestinationAddr: test.expectedDestAddr, + })) + + if got := c.s.Stats().UDP.PacketsReceived.Value(); got != 1 { + t.Fatalf("Read did not increment PacketsReceived: got = %d, want = 1", got) + } + }) + } +} + func TestWriteIncrementsPacketsSent(t *testing.T) { c := newDualTestContext(t, defaultMTU) defer c.cleanup() diff --git a/pkg/test/dockerutil/dockerutil.go b/pkg/test/dockerutil/dockerutil.go index 5a9dd8bd8..952871f95 100644 --- a/pkg/test/dockerutil/dockerutil.go +++ b/pkg/test/dockerutil/dockerutil.go @@ -88,44 +88,74 @@ func EnsureSupportedDockerVersion() { // RuntimePath returns the binary path for the current runtime. func RuntimePath() (string, error) { + rs, err := runtimeMap() + if err != nil { + return "", err + } + + p, ok := rs["path"].(string) + if !ok { + // The runtime does not declare a path. + return "", fmt.Errorf("runtime does not declare a path: %v", rs) + } + return p, nil +} + +// UsingVFS2 returns true if the 'runtime' has the vfs2 flag set. +// TODO(gvisor.dev/issue/1624): Remove. +func UsingVFS2() (bool, error) { + rMap, err := runtimeMap() + if err != nil { + return false, err + } + + list, ok := rMap["runtimeArgs"].([]interface{}) + if !ok { + return false, fmt.Errorf("unexpected format: %v", rMap) + } + + for _, element := range list { + if element == "--vfs2" { + return true, nil + } + } + return false, nil +} + +func runtimeMap() (map[string]interface{}, error) { // Read the configuration data; the file must exist. configBytes, err := ioutil.ReadFile(*config) if err != nil { - return "", err + return nil, err } // Unmarshal the configuration. c := make(map[string]interface{}) if err := json.Unmarshal(configBytes, &c); err != nil { - return "", err + return nil, err } // Decode the expected configuration. r, ok := c["runtimes"] if !ok { - return "", fmt.Errorf("no runtimes declared: %v", c) + return nil, fmt.Errorf("no runtimes declared: %v", c) } rs, ok := r.(map[string]interface{}) if !ok { // The runtimes are not a map. - return "", fmt.Errorf("unexpected format: %v", c) + return nil, fmt.Errorf("unexpected format: %v", rs) } r, ok = rs[*runtime] if !ok { // The expected runtime is not declared. - return "", fmt.Errorf("runtime %q not found: %v", *runtime, c) + return nil, fmt.Errorf("runtime %q not found: %v", *runtime, rs) } rs, ok = r.(map[string]interface{}) if !ok { // The runtime is not a map. - return "", fmt.Errorf("unexpected format: %v", c) - } - p, ok := rs["path"].(string) - if !ok { - // The runtime does not declare a path. - return "", fmt.Errorf("unexpected format: %v", c) + return nil, fmt.Errorf("unexpected format: %v", r) } - return p, nil + return rs, nil } // Save exports a container image to the given Writer. diff --git a/runsc/boot/loader.go b/runsc/boot/loader.go index 533b9c5e7..40c6f99fd 100644 --- a/runsc/boot/loader.go +++ b/runsc/boot/loader.go @@ -32,6 +32,7 @@ import ( "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/memutil" "gvisor.dev/gvisor/pkg/rand" + "gvisor.dev/gvisor/pkg/refs" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/control" "gvisor.dev/gvisor/pkg/sentry/fdimport" @@ -1011,6 +1012,8 @@ func (l *Loader) WaitExit() kernel.ExitStatus { // Cleanup l.ctrl.stop() + refs.OnExit() + return l.k.GlobalInit().ExitStatus() } diff --git a/scripts/docker_tests.sh b/scripts/docker_tests.sh index be0b0a3ec..4f3867d05 100755 --- a/scripts/docker_tests.sh +++ b/scripts/docker_tests.sh @@ -22,6 +22,4 @@ install_runsc_for_test docker test_runsc //test/image:image_test //test/e2e:integration_test install_runsc_for_test docker --vfs2 -IMAGE_FILTER="Hello|Httpd|Ruby|Stdio" -INTEGRATION_FILTER="LifeCycle|Pause|Connect|JobControl|Overlay|Exec|DirCreation/root|Link" -test_runsc //test/e2e:integration_test //test/image:image_test --test_filter="${IMAGE_FILTER}|${INTEGRATION_FILTER}" +test_runsc //test/e2e:integration_test //test/image:image_test diff --git a/test/benchmarks/base/BUILD b/test/benchmarks/base/BUILD index 5e099d0f9..32c139204 100644 --- a/test/benchmarks/base/BUILD +++ b/test/benchmarks/base/BUILD @@ -25,6 +25,7 @@ go_test( "manual", "local", ], + visibility = ["//:sandbox"], deps = [ "//pkg/test/dockerutil", "//test/benchmarks/harness", diff --git a/test/benchmarks/database/BUILD b/test/benchmarks/database/BUILD index 6139f6e8a..93b380e8a 100644 --- a/test/benchmarks/database/BUILD +++ b/test/benchmarks/database/BUILD @@ -19,6 +19,7 @@ go_test( "manual", "local", ], + visibility = ["//:sandbox"], deps = [ "//pkg/test/dockerutil", "//test/benchmarks/harness", diff --git a/test/benchmarks/fs/BUILD b/test/benchmarks/fs/BUILD index 20654d88f..45f11372b 100644 --- a/test/benchmarks/fs/BUILD +++ b/test/benchmarks/fs/BUILD @@ -22,6 +22,7 @@ go_test( "local", "manual", ], + visibility = ["//:sandbox"], deps = [ "//pkg/test/dockerutil", "//test/benchmarks/harness", diff --git a/test/benchmarks/media/BUILD b/test/benchmarks/media/BUILD index 6c41fc4f6..bb242d385 100644 --- a/test/benchmarks/media/BUILD +++ b/test/benchmarks/media/BUILD @@ -14,6 +14,7 @@ go_test( size = "large", srcs = ["ffmpeg_test.go"], library = ":media", + visibility = ["//:sandbox"], deps = [ "//pkg/test/dockerutil", "//test/benchmarks/harness", diff --git a/test/benchmarks/ml/BUILD b/test/benchmarks/ml/BUILD index 2430b60a7..970f52706 100644 --- a/test/benchmarks/ml/BUILD +++ b/test/benchmarks/ml/BUILD @@ -14,6 +14,7 @@ go_test( size = "large", srcs = ["tensorflow_test.go"], library = ":ml", + visibility = ["//:sandbox"], deps = [ "//pkg/test/dockerutil", "//test/benchmarks/harness", diff --git a/test/benchmarks/network/BUILD b/test/benchmarks/network/BUILD index df5ff7265..bd3f6245c 100644 --- a/test/benchmarks/network/BUILD +++ b/test/benchmarks/network/BUILD @@ -25,6 +25,7 @@ go_test( "manual", "local", ], + visibility = ["//:sandbox"], deps = [ "//pkg/test/dockerutil", "//pkg/test/testutil", diff --git a/test/e2e/integration_test.go b/test/e2e/integration_test.go index 71ec4791e..809244bab 100644 --- a/test/e2e/integration_test.go +++ b/test/e2e/integration_test.go @@ -167,6 +167,13 @@ func TestCheckpointRestore(t *testing.T) { t.Skip("Pause/resume is not supported.") } + // TODO(gvisor.dev/issue/3373): Remove after implementing. + if usingVFS2, err := dockerutil.UsingVFS2(); usingVFS2 { + t.Skip("CheckpointRestore not implemented in VFS2.") + } else if err != nil { + t.Fatalf("failed to read config for runtime %s: %v", dockerutil.Runtime(), err) + } + ctx := context.Background() d := dockerutil.MakeContainer(ctx, t) defer d.CleanUp(ctx) diff --git a/test/fuse/linux/fuse_base.cc b/test/fuse/linux/fuse_base.cc index 4a2c64998..9c3124472 100644 --- a/test/fuse/linux/fuse_base.cc +++ b/test/fuse/linux/fuse_base.cc @@ -25,8 +25,8 @@ #include <iostream> -#include "absl/strings/str_format.h" #include "gtest/gtest.h" +#include "absl/strings/str_format.h" #include "test/util/posix_error.h" #include "test/util/temp_path.h" #include "test/util/test_util.h" diff --git a/test/runner/defs.bzl b/test/runner/defs.bzl index 1ae13a4a5..2d64934b0 100644 --- a/test/runner/defs.bzl +++ b/test/runner/defs.bzl @@ -132,7 +132,7 @@ def syscall_test( add_overlay = False, add_uds_tree = False, add_hostinet = False, - vfs2 = False, + vfs2 = True, fuse = False, tags = None): """syscall_test is a macro that will create targets for all platforms. diff --git a/test/runner/gtest/gtest.go b/test/runner/gtest/gtest.go index 869169ad5..e4445e01b 100644 --- a/test/runner/gtest/gtest.go +++ b/test/runner/gtest/gtest.go @@ -146,10 +146,13 @@ func ParseTestCases(testBin string, benchmarks bool, extraArgs ...string) ([]Tes return nil, fmt.Errorf("could not enumerate gtest benchmarks: %v\nstderr\n%s", err, exitErr.Stderr) } - out = []byte(strings.Trim(string(out), "\n")) + benches := strings.Trim(string(out), "\n") + if len(benches) == 0 { + return t, nil + } // Parse benchmark output. - for _, line := range strings.Split(string(out), "\n") { + for _, line := range strings.Split(benches, "\n") { // Strip comments. line = strings.Split(line, "#")[0] @@ -163,6 +166,5 @@ func ParseTestCases(testBin string, benchmarks bool, extraArgs ...string) ([]Tes benchmark: true, }) } - return t, nil } diff --git a/test/runner/runner.go b/test/runner/runner.go index bc4b39cbb..5ac91310d 100644 --- a/test/runner/runner.go +++ b/test/runner/runner.go @@ -172,13 +172,14 @@ func runRunsc(tc gtest.TestCase, spec *specs.Spec) error { args = append(args, "-fsgofer-host-uds") } - undeclaredOutputsDir, ok := syscall.Getenv("TEST_UNDECLARED_OUTPUTS_DIR") - if ok { - tdir := filepath.Join(undeclaredOutputsDir, strings.Replace(name, "/", "_", -1)) - if err := os.MkdirAll(tdir, 0755); err != nil { + testLogDir := "" + if undeclaredOutputsDir, ok := syscall.Getenv("TEST_UNDECLARED_OUTPUTS_DIR"); ok { + // Create log directory dedicated for this test. + testLogDir = filepath.Join(undeclaredOutputsDir, strings.Replace(name, "/", "_", -1)) + if err := os.MkdirAll(testLogDir, 0755); err != nil { return fmt.Errorf("could not create test dir: %v", err) } - debugLogDir, err := ioutil.TempDir(tdir, "runsc") + debugLogDir, err := ioutil.TempDir(testLogDir, "runsc") if err != nil { return fmt.Errorf("could not create temp dir: %v", err) } @@ -227,10 +228,10 @@ func runRunsc(tc gtest.TestCase, spec *specs.Spec) error { dArgs := append([]string{}, args...) dArgs = append(dArgs, "-alsologtostderr=true", "debug", "--stacks", id) go func(dArgs []string) { - cmd := exec.Command(*runscPath, dArgs...) - cmd.Stdout = os.Stdout - cmd.Stderr = os.Stderr - cmd.Run() + debug := exec.Command(*runscPath, dArgs...) + debug.Stdout = os.Stdout + debug.Stderr = os.Stderr + debug.Run() done <- true }(dArgs) @@ -245,17 +246,17 @@ func runRunsc(tc gtest.TestCase, spec *specs.Spec) error { dArgs = append(args, "debug", fmt.Sprintf("--signal=%d", syscall.SIGTERM), id) - cmd := exec.Command(*runscPath, dArgs...) - cmd.Stdout = os.Stdout - cmd.Stderr = os.Stderr - cmd.Run() + signal := exec.Command(*runscPath, dArgs...) + signal.Stdout = os.Stdout + signal.Stderr = os.Stderr + signal.Run() }() err = cmd.Run() - if err == nil { + if err == nil && len(testLogDir) > 0 { // If the test passed, then we erase the log directory. This speeds up // uploading logs in continuous integration & saves on disk space. - os.RemoveAll(undeclaredOutputsDir) + os.RemoveAll(testLogDir) } return err diff --git a/test/syscalls/BUILD b/test/syscalls/BUILD index a31612b41..571785ad2 100644 --- a/test/syscalls/BUILD +++ b/test/syscalls/BUILD @@ -4,75 +4,62 @@ package(licenses = ["notice"]) syscall_test( test = "//test/syscalls/linux:32bit_test", - vfs2 = "True", ) syscall_test( test = "//test/syscalls/linux:accept_bind_stream_test", - vfs2 = "True", ) syscall_test( size = "large", shard_count = 50, test = "//test/syscalls/linux:accept_bind_test", - vfs2 = "True", ) syscall_test( add_overlay = True, test = "//test/syscalls/linux:access_test", - vfs2 = "True", ) syscall_test( test = "//test/syscalls/linux:affinity_test", - vfs2 = "True", ) syscall_test( add_overlay = True, test = "//test/syscalls/linux:aio_test", - vfs2 = "True", ) syscall_test( size = "medium", shard_count = 5, test = "//test/syscalls/linux:alarm_test", - vfs2 = "True", ) syscall_test( test = "//test/syscalls/linux:arch_prctl_test", - vfs2 = "True", ) syscall_test( test = "//test/syscalls/linux:bad_test", - vfs2 = "True", ) syscall_test( size = "large", add_overlay = True, test = "//test/syscalls/linux:bind_test", - vfs2 = "True", ) syscall_test( test = "//test/syscalls/linux:brk_test", - vfs2 = "True", ) syscall_test( test = "//test/syscalls/linux:socket_test", - vfs2 = "True", ) syscall_test( test = "//test/syscalls/linux:socket_capability_test", - vfs2 = "True", ) syscall_test( @@ -82,19 +69,16 @@ syscall_test( # involve much concurrency, TSAN's usefulness here is limited anyway. tags = ["nogotsan"], test = "//test/syscalls/linux:socket_stress_test", - vfs2 = "True", ) syscall_test( add_overlay = True, test = "//test/syscalls/linux:chdir_test", - vfs2 = "True", ) syscall_test( add_overlay = True, test = "//test/syscalls/linux:chmod_test", - vfs2 = "True", ) syscall_test( @@ -102,116 +86,96 @@ syscall_test( add_overlay = True, test = "//test/syscalls/linux:chown_test", use_tmpfs = True, # chwon tests require gofer to be running as root. - vfs2 = "True", ) syscall_test( add_overlay = True, test = "//test/syscalls/linux:chroot_test", - vfs2 = "True", ) syscall_test( test = "//test/syscalls/linux:clock_getres_test", - vfs2 = "True", ) syscall_test( size = "medium", test = "//test/syscalls/linux:clock_gettime_test", - vfs2 = "True", ) syscall_test( test = "//test/syscalls/linux:clock_nanosleep_test", - vfs2 = "True", ) syscall_test( test = "//test/syscalls/linux:concurrency_test", - vfs2 = "True", ) syscall_test( add_uds_tree = True, test = "//test/syscalls/linux:connect_external_test", use_tmpfs = True, - vfs2 = "True", ) syscall_test( add_overlay = True, test = "//test/syscalls/linux:creat_test", - vfs2 = "True", ) syscall_test( fuse = "True", test = "//test/syscalls/linux:dev_test", - vfs2 = "True", ) syscall_test( add_overlay = True, test = "//test/syscalls/linux:dup_test", - vfs2 = "True", ) syscall_test( test = "//test/syscalls/linux:epoll_test", - vfs2 = "True", ) syscall_test( test = "//test/syscalls/linux:eventfd_test", - vfs2 = "True", ) syscall_test( test = "//test/syscalls/linux:exceptions_test", - vfs2 = "True", ) syscall_test( size = "medium", add_overlay = True, test = "//test/syscalls/linux:exec_test", - vfs2 = "True", ) syscall_test( size = "medium", add_overlay = True, test = "//test/syscalls/linux:exec_binary_test", - vfs2 = "True", ) syscall_test( test = "//test/syscalls/linux:exit_test", - vfs2 = "True", ) syscall_test( add_overlay = True, test = "//test/syscalls/linux:fadvise64_test", - vfs2 = "True", ) syscall_test( add_overlay = True, test = "//test/syscalls/linux:fallocate_test", - vfs2 = "True", ) syscall_test( test = "//test/syscalls/linux:fault_test", - vfs2 = "True", ) syscall_test( add_overlay = True, test = "//test/syscalls/linux:fchdir_test", - vfs2 = "True", ) syscall_test( @@ -223,204 +187,168 @@ syscall_test( size = "medium", add_overlay = True, test = "//test/syscalls/linux:flock_test", - vfs2 = "True", ) syscall_test( test = "//test/syscalls/linux:fork_test", - vfs2 = "True", ) syscall_test( test = "//test/syscalls/linux:fpsig_fork_test", - vfs2 = "True", ) syscall_test( test = "//test/syscalls/linux:fpsig_nested_test", - vfs2 = "True", ) syscall_test( add_overlay = True, test = "//test/syscalls/linux:fsync_test", - vfs2 = "True", ) syscall_test( size = "medium", shard_count = 5, test = "//test/syscalls/linux:futex_test", - vfs2 = "True", ) syscall_test( test = "//test/syscalls/linux:getcpu_host_test", - vfs2 = "True", ) syscall_test( test = "//test/syscalls/linux:getcpu_test", - vfs2 = "True", ) syscall_test( add_overlay = True, test = "//test/syscalls/linux:getdents_test", - vfs2 = "True", ) syscall_test( test = "//test/syscalls/linux:getrandom_test", - vfs2 = "True", ) syscall_test( test = "//test/syscalls/linux:getrusage_test", - vfs2 = "True", ) syscall_test( size = "medium", add_overlay = False, # TODO(gvisor.dev/issue/317): enable when fixed. test = "//test/syscalls/linux:inotify_test", - vfs2 = "True", ) syscall_test( size = "medium", add_overlay = True, test = "//test/syscalls/linux:ioctl_test", - vfs2 = "True", ) syscall_test( test = "//test/syscalls/linux:iptables_test", - vfs2 = "True", ) syscall_test( size = "large", shard_count = 5, test = "//test/syscalls/linux:itimer_test", - vfs2 = "True", ) syscall_test( test = "//test/syscalls/linux:kill_test", - vfs2 = "True", ) syscall_test( add_overlay = True, test = "//test/syscalls/linux:link_test", use_tmpfs = True, # gofer needs CAP_DAC_READ_SEARCH to use AT_EMPTY_PATH with linkat(2) - vfs2 = "True", ) syscall_test( add_overlay = True, test = "//test/syscalls/linux:lseek_test", - vfs2 = "True", ) syscall_test( test = "//test/syscalls/linux:madvise_test", - vfs2 = "True", ) syscall_test( test = "//test/syscalls/linux:memory_accounting_test", - vfs2 = "True", ) syscall_test( test = "//test/syscalls/linux:mempolicy_test", - vfs2 = "True", ) syscall_test( test = "//test/syscalls/linux:mincore_test", - vfs2 = "True", ) syscall_test( add_overlay = True, test = "//test/syscalls/linux:mkdir_test", - vfs2 = "True", ) syscall_test( add_overlay = True, test = "//test/syscalls/linux:mknod_test", - vfs2 = "True", ) syscall_test( size = "medium", shard_count = 5, test = "//test/syscalls/linux:mmap_test", - vfs2 = "True", ) syscall_test( add_overlay = True, test = "//test/syscalls/linux:mount_test", - vfs2 = "True", ) syscall_test( size = "medium", test = "//test/syscalls/linux:mremap_test", - vfs2 = "True", ) syscall_test( size = "medium", test = "//test/syscalls/linux:msync_test", - vfs2 = "True", ) syscall_test( test = "//test/syscalls/linux:munmap_test", - vfs2 = "True", ) syscall_test( test = "//test/syscalls/linux:network_namespace_test", - vfs2 = "True", ) syscall_test( add_overlay = True, test = "//test/syscalls/linux:open_create_test", - vfs2 = "True", ) syscall_test( add_overlay = True, test = "//test/syscalls/linux:open_test", - vfs2 = "True", ) syscall_test( test = "//test/syscalls/linux:packet_socket_raw_test", - vfs2 = "True", ) syscall_test( test = "//test/syscalls/linux:packet_socket_test", - vfs2 = "True", ) syscall_test( test = "//test/syscalls/linux:partial_bad_buffer_test", - vfs2 = "True", ) syscall_test( test = "//test/syscalls/linux:pause_test", - vfs2 = "True", ) syscall_test( @@ -428,7 +356,6 @@ syscall_test( # Takes too long under gotsan to run. tags = ["nogotsan"], test = "//test/syscalls/linux:ping_socket_test", - vfs2 = "True", ) syscall_test( @@ -436,229 +363,188 @@ syscall_test( add_overlay = True, shard_count = 5, test = "//test/syscalls/linux:pipe_test", - vfs2 = "True", ) syscall_test( test = "//test/syscalls/linux:poll_test", - vfs2 = "True", ) syscall_test( size = "medium", test = "//test/syscalls/linux:ppoll_test", - vfs2 = "True", ) syscall_test( test = "//test/syscalls/linux:prctl_setuid_test", - vfs2 = "True", ) syscall_test( test = "//test/syscalls/linux:prctl_test", - vfs2 = "True", ) syscall_test( add_overlay = True, test = "//test/syscalls/linux:pread64_test", - vfs2 = "True", ) syscall_test( add_overlay = True, test = "//test/syscalls/linux:preadv_test", - vfs2 = "True", ) syscall_test( add_overlay = True, test = "//test/syscalls/linux:preadv2_test", - vfs2 = "True", ) syscall_test( test = "//test/syscalls/linux:priority_test", - vfs2 = "True", ) syscall_test( size = "medium", test = "//test/syscalls/linux:proc_test", - vfs2 = "True", ) syscall_test( test = "//test/syscalls/linux:proc_net_test", - vfs2 = "True", ) syscall_test( test = "//test/syscalls/linux:proc_pid_oomscore_test", - vfs2 = "True", ) syscall_test( test = "//test/syscalls/linux:proc_pid_smaps_test", - vfs2 = "True", ) syscall_test( test = "//test/syscalls/linux:proc_pid_uid_gid_map_test", - vfs2 = "True", ) syscall_test( size = "medium", test = "//test/syscalls/linux:pselect_test", - vfs2 = "True", ) syscall_test( test = "//test/syscalls/linux:ptrace_test", - vfs2 = "True", ) syscall_test( size = "medium", shard_count = 5, test = "//test/syscalls/linux:pty_test", - vfs2 = "True", ) syscall_test( test = "//test/syscalls/linux:pty_root_test", - vfs2 = "True", ) syscall_test( add_overlay = True, test = "//test/syscalls/linux:pwritev2_test", - vfs2 = "True", ) syscall_test( add_overlay = True, test = "//test/syscalls/linux:pwrite64_test", - vfs2 = "True", ) syscall_test( test = "//test/syscalls/linux:raw_socket_hdrincl_test", - vfs2 = "True", ) syscall_test( test = "//test/syscalls/linux:raw_socket_icmp_test", - vfs2 = "True", ) syscall_test( test = "//test/syscalls/linux:raw_socket_test", - vfs2 = "True", ) syscall_test( add_overlay = True, test = "//test/syscalls/linux:read_test", - vfs2 = "True", ) syscall_test( add_overlay = True, test = "//test/syscalls/linux:readahead_test", - vfs2 = "True", ) syscall_test( size = "medium", shard_count = 5, test = "//test/syscalls/linux:readv_socket_test", - vfs2 = "True", ) syscall_test( size = "medium", add_overlay = True, test = "//test/syscalls/linux:readv_test", - vfs2 = "True", ) syscall_test( size = "medium", add_overlay = True, test = "//test/syscalls/linux:rename_test", - vfs2 = "True", ) syscall_test( test = "//test/syscalls/linux:rlimits_test", - vfs2 = "True", ) syscall_test( test = "//test/syscalls/linux:rseq_test", - vfs2 = "True", ) syscall_test( test = "//test/syscalls/linux:rtsignal_test", - vfs2 = "True", ) syscall_test( test = "//test/syscalls/linux:signalfd_test", - vfs2 = "True", ) syscall_test( test = "//test/syscalls/linux:sched_test", - vfs2 = "True", ) syscall_test( test = "//test/syscalls/linux:sched_yield_test", - vfs2 = "True", ) syscall_test( test = "//test/syscalls/linux:seccomp_test", - vfs2 = "True", ) syscall_test( test = "//test/syscalls/linux:select_test", - vfs2 = "True", ) syscall_test( shard_count = 20, test = "//test/syscalls/linux:semaphore_test", - vfs2 = "True", ) syscall_test( add_overlay = True, test = "//test/syscalls/linux:sendfile_socket_test", - vfs2 = "True", ) syscall_test( add_overlay = True, test = "//test/syscalls/linux:sendfile_test", - vfs2 = "True", ) syscall_test( add_overlay = True, test = "//test/syscalls/linux:splice_test", - vfs2 = "True", ) syscall_test( test = "//test/syscalls/linux:sigaction_test", - vfs2 = "True", ) # TODO(b/119826902): Enable once the test passes in runsc. @@ -666,62 +552,52 @@ syscall_test( syscall_test( test = "//test/syscalls/linux:sigiret_test", - vfs2 = "True", ) syscall_test( test = "//test/syscalls/linux:sigprocmask_test", - vfs2 = "True", ) syscall_test( size = "medium", test = "//test/syscalls/linux:sigstop_test", - vfs2 = "True", ) syscall_test( test = "//test/syscalls/linux:sigtimedwait_test", - vfs2 = "True", ) syscall_test( size = "medium", test = "//test/syscalls/linux:shm_test", - vfs2 = "True", ) syscall_test( size = "medium", test = "//test/syscalls/linux:socket_abstract_non_blocking_test", - vfs2 = "True", ) syscall_test( size = "large", shard_count = 50, test = "//test/syscalls/linux:socket_abstract_test", - vfs2 = "True", ) syscall_test( size = "medium", test = "//test/syscalls/linux:socket_domain_non_blocking_test", - vfs2 = "True", ) syscall_test( size = "large", shard_count = 50, test = "//test/syscalls/linux:socket_domain_test", - vfs2 = "True", ) syscall_test( size = "medium", add_overlay = True, test = "//test/syscalls/linux:socket_filesystem_non_blocking_test", - vfs2 = "True", ) syscall_test( @@ -729,14 +605,12 @@ syscall_test( add_overlay = True, shard_count = 50, test = "//test/syscalls/linux:socket_filesystem_test", - vfs2 = "True", ) syscall_test( size = "large", shard_count = 50, test = "//test/syscalls/linux:socket_inet_loopback_test", - vfs2 = "True", ) syscall_test( @@ -745,122 +619,101 @@ syscall_test( # Takes too long for TSAN. Creates a lot of TCP sockets. tags = ["nogotsan"], test = "//test/syscalls/linux:socket_inet_loopback_nogotsan_test", - vfs2 = "True", ) syscall_test( size = "large", shard_count = 50, test = "//test/syscalls/linux:socket_ip_tcp_generic_loopback_test", - vfs2 = "True", ) syscall_test( size = "medium", test = "//test/syscalls/linux:socket_ip_tcp_loopback_non_blocking_test", - vfs2 = "True", ) syscall_test( size = "large", shard_count = 50, test = "//test/syscalls/linux:socket_ip_tcp_loopback_test", - vfs2 = "True", ) syscall_test( size = "medium", shard_count = 50, test = "//test/syscalls/linux:socket_ip_tcp_udp_generic_loopback_test", - vfs2 = "True", ) syscall_test( size = "medium", test = "//test/syscalls/linux:socket_ip_udp_loopback_non_blocking_test", - vfs2 = "True", ) syscall_test( size = "large", shard_count = 50, test = "//test/syscalls/linux:socket_ip_udp_loopback_test", - vfs2 = "True", ) syscall_test( size = "medium", test = "//test/syscalls/linux:socket_ipv4_udp_unbound_loopback_test", - vfs2 = "True", ) syscall_test( test = "//test/syscalls/linux:socket_ip_unbound_test", - vfs2 = "True", ) syscall_test( test = "//test/syscalls/linux:socket_netdevice_test", - vfs2 = "True", ) syscall_test( test = "//test/syscalls/linux:socket_netlink_test", - vfs2 = "True", ) syscall_test( test = "//test/syscalls/linux:socket_netlink_route_test", - vfs2 = "True", ) syscall_test( test = "//test/syscalls/linux:socket_netlink_uevent_test", - vfs2 = "True", ) syscall_test( test = "//test/syscalls/linux:socket_blocking_local_test", - vfs2 = "True", ) syscall_test( test = "//test/syscalls/linux:socket_blocking_ip_test", - vfs2 = "True", ) syscall_test( test = "//test/syscalls/linux:socket_non_stream_blocking_local_test", - vfs2 = "True", ) syscall_test( test = "//test/syscalls/linux:socket_non_stream_blocking_udp_test", - vfs2 = "True", ) syscall_test( size = "large", test = "//test/syscalls/linux:socket_stream_blocking_local_test", - vfs2 = "True", ) syscall_test( size = "large", test = "//test/syscalls/linux:socket_stream_blocking_tcp_test", - vfs2 = "True", ) syscall_test( size = "medium", test = "//test/syscalls/linux:socket_stream_local_test", - vfs2 = "True", ) syscall_test( size = "medium", test = "//test/syscalls/linux:socket_stream_nonblock_local_test", - vfs2 = "True", ) syscall_test( @@ -868,13 +721,11 @@ syscall_test( size = "enormous", shard_count = 5, test = "//test/syscalls/linux:socket_unix_dgram_local_test", - vfs2 = "True", ) syscall_test( size = "medium", test = "//test/syscalls/linux:socket_unix_dgram_non_blocking_test", - vfs2 = "True", ) syscall_test( @@ -882,7 +733,6 @@ syscall_test( add_overlay = True, shard_count = 50, test = "//test/syscalls/linux:socket_unix_pair_test", - vfs2 = "True", ) syscall_test( @@ -890,156 +740,129 @@ syscall_test( size = "enormous", shard_count = 5, test = "//test/syscalls/linux:socket_unix_seqpacket_local_test", - vfs2 = "True", ) syscall_test( size = "medium", test = "//test/syscalls/linux:socket_unix_stream_test", - vfs2 = "True", ) syscall_test( size = "medium", test = "//test/syscalls/linux:socket_unix_unbound_abstract_test", - vfs2 = "True", ) syscall_test( size = "medium", test = "//test/syscalls/linux:socket_unix_unbound_dgram_test", - vfs2 = "True", ) syscall_test( size = "medium", test = "//test/syscalls/linux:socket_unix_unbound_filesystem_test", - vfs2 = "True", ) syscall_test( size = "medium", shard_count = 10, test = "//test/syscalls/linux:socket_unix_unbound_seqpacket_test", - vfs2 = "True", ) syscall_test( size = "large", shard_count = 50, test = "//test/syscalls/linux:socket_unix_unbound_stream_test", - vfs2 = "True", ) syscall_test( add_overlay = True, test = "//test/syscalls/linux:statfs_test", - vfs2 = "True", ) syscall_test( add_overlay = True, test = "//test/syscalls/linux:stat_test", - vfs2 = "True", ) syscall_test( add_overlay = True, test = "//test/syscalls/linux:stat_times_test", - vfs2 = "True", ) syscall_test( add_overlay = True, test = "//test/syscalls/linux:sticky_test", - vfs2 = "True", ) syscall_test( add_overlay = True, test = "//test/syscalls/linux:symlink_test", - vfs2 = "True", ) syscall_test( add_overlay = True, test = "//test/syscalls/linux:sync_test", - vfs2 = "True", ) syscall_test( add_overlay = True, test = "//test/syscalls/linux:sync_file_range_test", - vfs2 = "True", ) syscall_test( test = "//test/syscalls/linux:sysinfo_test", - vfs2 = "True", ) syscall_test( test = "//test/syscalls/linux:syslog_test", - vfs2 = "True", ) syscall_test( test = "//test/syscalls/linux:sysret_test", - vfs2 = "True", ) syscall_test( size = "medium", shard_count = 10, test = "//test/syscalls/linux:tcp_socket_test", - vfs2 = "True", ) syscall_test( test = "//test/syscalls/linux:tgkill_test", - vfs2 = "True", ) syscall_test( test = "//test/syscalls/linux:timerfd_test", - vfs2 = "True", ) syscall_test( test = "//test/syscalls/linux:timers_test", - vfs2 = "True", ) syscall_test( test = "//test/syscalls/linux:time_test", - vfs2 = "True", ) syscall_test( test = "//test/syscalls/linux:tkill_test", - vfs2 = "True", ) syscall_test( add_overlay = True, test = "//test/syscalls/linux:truncate_test", - vfs2 = "True", ) syscall_test( test = "//test/syscalls/linux:tuntap_test", - vfs2 = "True", ) syscall_test( add_hostinet = True, test = "//test/syscalls/linux:tuntap_hostinet_test", - vfs2 = "True", ) syscall_test( test = "//test/syscalls/linux:udp_bind_test", - vfs2 = "True", ) syscall_test( @@ -1047,80 +870,65 @@ syscall_test( add_hostinet = True, shard_count = 10, test = "//test/syscalls/linux:udp_socket_test", - vfs2 = "True", ) syscall_test( test = "//test/syscalls/linux:uidgid_test", - vfs2 = "True", ) syscall_test( test = "//test/syscalls/linux:uname_test", - vfs2 = "True", ) syscall_test( add_overlay = True, test = "//test/syscalls/linux:unlink_test", - vfs2 = "True", ) syscall_test( test = "//test/syscalls/linux:unshare_test", - vfs2 = "True", ) syscall_test( test = "//test/syscalls/linux:utimes_test", - vfs2 = "True", ) syscall_test( size = "medium", test = "//test/syscalls/linux:vdso_clock_gettime_test", - vfs2 = "True", ) syscall_test( test = "//test/syscalls/linux:vdso_test", - vfs2 = "True", ) syscall_test( test = "//test/syscalls/linux:vsyscall_test", - vfs2 = "True", ) syscall_test( test = "//test/syscalls/linux:vfork_test", - vfs2 = "True", ) syscall_test( size = "medium", shard_count = 5, test = "//test/syscalls/linux:wait_test", - vfs2 = "True", ) syscall_test( add_overlay = True, test = "//test/syscalls/linux:write_test", - vfs2 = "True", ) syscall_test( test = "//test/syscalls/linux:proc_net_unix_test", - vfs2 = "True", ) syscall_test( test = "//test/syscalls/linux:proc_net_tcp_test", - vfs2 = "True", ) syscall_test( test = "//test/syscalls/linux:proc_net_udp_test", - vfs2 = "True", ) diff --git a/test/syscalls/linux/socket_ipv4_udp_unbound.cc b/test/syscalls/linux/socket_ipv4_udp_unbound.cc index de0f5f01b..bc005e2bb 100644 --- a/test/syscalls/linux/socket_ipv4_udp_unbound.cc +++ b/test/syscalls/linux/socket_ipv4_udp_unbound.cc @@ -2452,5 +2452,105 @@ TEST_P(IPv4UDPUnboundSocketTest, SetSocketSendBuf) { ASSERT_EQ(quarter_sz, val); } + +TEST_P(IPv4UDPUnboundSocketTest, IpMulticastIPPacketInfo) { + auto sender_socket = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); + auto receiver_socket = ASSERT_NO_ERRNO_AND_VALUE(NewSocket()); + + // Bind the first FD to the loopback. This is an alternative to + // IP_MULTICAST_IF for setting the default send interface. + auto sender_addr = V4Loopback(); + ASSERT_THAT( + bind(sender_socket->get(), reinterpret_cast<sockaddr*>(&sender_addr.addr), + sender_addr.addr_len), + SyscallSucceeds()); + + // Bind the second FD to the v4 any address to ensure that we can receive the + // multicast packet. + auto receiver_addr = V4Any(); + ASSERT_THAT(bind(receiver_socket->get(), + reinterpret_cast<sockaddr*>(&receiver_addr.addr), + receiver_addr.addr_len), + SyscallSucceeds()); + socklen_t receiver_addr_len = receiver_addr.addr_len; + ASSERT_THAT(getsockname(receiver_socket->get(), + reinterpret_cast<sockaddr*>(&receiver_addr.addr), + &receiver_addr_len), + SyscallSucceeds()); + EXPECT_EQ(receiver_addr_len, receiver_addr.addr_len); + + // Register to receive multicast packets. + ip_mreqn group = {}; + group.imr_multiaddr.s_addr = inet_addr(kMulticastAddress); + group.imr_ifindex = ASSERT_NO_ERRNO_AND_VALUE(InterfaceIndex("lo")); + ASSERT_THAT(setsockopt(receiver_socket->get(), IPPROTO_IP, IP_ADD_MEMBERSHIP, + &group, sizeof(group)), + SyscallSucceeds()); + + // Register to receive IP packet info. + const int one = 1; + ASSERT_THAT(setsockopt(receiver_socket->get(), IPPROTO_IP, IP_PKTINFO, &one, + sizeof(one)), + SyscallSucceeds()); + + // Send a multicast packet. + auto send_addr = V4Multicast(); + reinterpret_cast<sockaddr_in*>(&send_addr.addr)->sin_port = + reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_port; + char send_buf[200]; + RandomizeBuffer(send_buf, sizeof(send_buf)); + ASSERT_THAT( + RetryEINTR(sendto)(sender_socket->get(), send_buf, sizeof(send_buf), 0, + reinterpret_cast<sockaddr*>(&send_addr.addr), + send_addr.addr_len), + SyscallSucceedsWithValue(sizeof(send_buf))); + + // Check that we received the multicast packet. + msghdr recv_msg = {}; + iovec recv_iov = {}; + char recv_buf[sizeof(send_buf)]; + char recv_cmsg_buf[CMSG_SPACE(sizeof(in_pktinfo))] = {}; + size_t cmsg_data_len = sizeof(in_pktinfo); + recv_iov.iov_base = recv_buf; + recv_iov.iov_len = sizeof(recv_buf); + recv_msg.msg_iov = &recv_iov; + recv_msg.msg_iovlen = 1; + recv_msg.msg_controllen = CMSG_LEN(cmsg_data_len); + recv_msg.msg_control = recv_cmsg_buf; + ASSERT_THAT(RetryEINTR(recvmsg)(receiver_socket->get(), &recv_msg, 0), + SyscallSucceedsWithValue(sizeof(send_buf))); + EXPECT_EQ(0, memcmp(send_buf, recv_buf, sizeof(send_buf))); + + // Check the IP_PKTINFO control message. + cmsghdr* cmsg = CMSG_FIRSTHDR(&recv_msg); + ASSERT_NE(cmsg, nullptr); + EXPECT_EQ(cmsg->cmsg_len, CMSG_LEN(cmsg_data_len)); + EXPECT_EQ(cmsg->cmsg_level, IPPROTO_IP); + EXPECT_EQ(cmsg->cmsg_type, IP_PKTINFO); + + // Get loopback index. + ifreq ifr = {}; + absl::SNPrintF(ifr.ifr_name, IFNAMSIZ, "lo"); + ASSERT_THAT(ioctl(receiver_socket->get(), SIOCGIFINDEX, &ifr), + SyscallSucceeds()); + ASSERT_NE(ifr.ifr_ifindex, 0); + + in_pktinfo received_pktinfo = {}; + memcpy(&received_pktinfo, CMSG_DATA(cmsg), sizeof(in_pktinfo)); + EXPECT_EQ(received_pktinfo.ipi_ifindex, ifr.ifr_ifindex); + if (IsRunningOnGvisor()) { + // This should actually be a unicast address assigned to the interface. + // + // TODO(gvisor.dev/issue/3556): This check is validating incorrect + // behaviour. We still include the test so that once the bug is + // resolved, this test will start to fail and the individual tasked + // with fixing this bug knows to also fix this test :). + EXPECT_EQ(received_pktinfo.ipi_spec_dst.s_addr, group.imr_multiaddr.s_addr); + } else { + EXPECT_EQ(received_pktinfo.ipi_spec_dst.s_addr, htonl(INADDR_LOOPBACK)); + } + EXPECT_EQ(received_pktinfo.ipi_addr.s_addr, group.imr_multiaddr.s_addr); +} + } // namespace testing } // namespace gvisor |