summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--benchmarks/defs.bzl14
-rw-r--r--benchmarks/harness/BUILD165
-rw-r--r--benchmarks/harness/machine_producers/BUILD5
-rw-r--r--benchmarks/runner/BUILD17
-rw-r--r--benchmarks/workloads/ab/BUILD13
-rw-r--r--benchmarks/workloads/absl/BUILD13
-rw-r--r--benchmarks/workloads/fio/BUILD13
-rw-r--r--benchmarks/workloads/iperf/BUILD13
-rw-r--r--benchmarks/workloads/redisbenchmark/BUILD13
-rw-r--r--benchmarks/workloads/sysbench/BUILD13
-rw-r--r--benchmarks/workloads/syscall/BUILD13
-rwxr-xr-xkokoro/runtime_tests/runtime_tests.sh6
-rw-r--r--pkg/abi/linux/socket.go13
-rw-r--r--pkg/binary/binary.go10
-rw-r--r--pkg/sentry/fs/fsutil/host_file_mapper.go17
-rw-r--r--pkg/sentry/fs/tty/slave.go2
-rw-r--r--pkg/sentry/fsimpl/gofer/regular_file.go5
-rw-r--r--pkg/sentry/socket/control/BUILD1
-rw-r--r--pkg/sentry/socket/control/control.go69
-rw-r--r--pkg/sentry/socket/hostinet/socket.go11
-rw-r--r--pkg/sentry/socket/netfilter/extensions.go7
-rw-r--r--pkg/sentry/socket/netlink/message.go15
-rw-r--r--pkg/sentry/socket/netstack/netstack.go37
-rw-r--r--pkg/sentry/strace/BUILD1
-rw-r--r--pkg/sentry/strace/socket.go7
-rw-r--r--pkg/sleep/commit_noasm.go13
-rw-r--r--pkg/sleep/sleep_unsafe.go23
-rw-r--r--pkg/tcpip/stack/nic.go5
-rw-r--r--pkg/tcpip/stack/stack.go9
-rw-r--r--pkg/tcpip/stack/stack_test.go85
-rw-r--r--pkg/tcpip/tcpip.go25
-rw-r--r--pkg/tcpip/transport/udp/endpoint.go26
-rw-r--r--runsc/sandbox/network.go30
-rwxr-xr-xscripts/common.sh12
-rwxr-xr-xscripts/common_build.sh27
-rw-r--r--test/runtimes/README.md31
-rw-r--r--test/syscalls/linux/BUILD3
-rw-r--r--test/syscalls/linux/socket_ip_udp_generic.cc44
-rw-r--r--test/syscalls/linux/socket_ipv4_udp_unbound.cc84
-rw-r--r--test/syscalls/linux/udp_socket_test_cases.cc1
-rw-r--r--tools/bazeldefs/defs.bzl2
-rw-r--r--tools/defs.bzl4
-rw-r--r--tools/go_stateify/defs.bzl4
-rw-r--r--tools/go_stateify/main.go10
44 files changed, 688 insertions, 243 deletions
diff --git a/benchmarks/defs.bzl b/benchmarks/defs.bzl
new file mode 100644
index 000000000..56d28223e
--- /dev/null
+++ b/benchmarks/defs.bzl
@@ -0,0 +1,14 @@
+"""Provides attributes common to many workload tests."""
+
+load("//tools:defs.bzl", "py_requirement")
+
+test_deps = [
+ py_requirement("attrs", direct = False),
+ py_requirement("atomicwrites", direct = False),
+ py_requirement("more-itertools", direct = False),
+ py_requirement("pathlib2", direct = False),
+ py_requirement("pluggy", direct = False),
+ py_requirement("py", direct = False),
+ py_requirement("pytest"),
+ py_requirement("six", direct = False),
+]
diff --git a/benchmarks/harness/BUILD b/benchmarks/harness/BUILD
index 4d03e3a06..48c548d59 100644
--- a/benchmarks/harness/BUILD
+++ b/benchmarks/harness/BUILD
@@ -1,5 +1,4 @@
-load("//tools:defs.bzl", "pkg_tar")
-load("//tools:defs.bzl", "py_library", "py_requirement")
+load("//tools:defs.bzl", "pkg_tar", "py_library", "py_requirement")
package(
default_visibility = ["//benchmarks:__subpackages__"],
@@ -46,16 +45,43 @@ py_library(
srcs = ["container.py"],
deps = [
"//benchmarks/workloads",
- py_requirement("asn1crypto", False),
- py_requirement("chardet", False),
- py_requirement("certifi", False),
- py_requirement("docker", True),
- py_requirement("docker-pycreds", False),
- py_requirement("idna", False),
- py_requirement("ptyprocess", False),
- py_requirement("requests", False),
- py_requirement("urllib3", False),
- py_requirement("websocket-client", False),
+ py_requirement(
+ "asn1crypto",
+ direct = False,
+ ),
+ py_requirement(
+ "chardet",
+ direct = False,
+ ),
+ py_requirement(
+ "certifi",
+ direct = False,
+ ),
+ py_requirement("docker"),
+ py_requirement(
+ "docker-pycreds",
+ direct = False,
+ ),
+ py_requirement(
+ "idna",
+ direct = False,
+ ),
+ py_requirement(
+ "ptyprocess",
+ direct = False,
+ ),
+ py_requirement(
+ "requests",
+ direct = False,
+ ),
+ py_requirement(
+ "urllib3",
+ direct = False,
+ ),
+ py_requirement(
+ "websocket-client",
+ direct = False,
+ ),
],
)
@@ -68,17 +94,47 @@ py_library(
"//benchmarks/harness:ssh_connection",
"//benchmarks/harness:tunnel_dispatcher",
"//benchmarks/harness/machine_mocks",
- py_requirement("asn1crypto", False),
- py_requirement("chardet", False),
- py_requirement("certifi", False),
- py_requirement("docker", True),
- py_requirement("docker-pycreds", False),
- py_requirement("idna", False),
- py_requirement("ptyprocess", False),
- py_requirement("requests", False),
- py_requirement("six", False),
- py_requirement("urllib3", False),
- py_requirement("websocket-client", False),
+ py_requirement(
+ "asn1crypto",
+ direct = False,
+ ),
+ py_requirement(
+ "chardet",
+ direct = False,
+ ),
+ py_requirement(
+ "certifi",
+ direct = False,
+ ),
+ py_requirement("docker"),
+ py_requirement(
+ "docker-pycreds",
+ direct = False,
+ ),
+ py_requirement(
+ "idna",
+ direct = False,
+ ),
+ py_requirement(
+ "ptyprocess",
+ direct = False,
+ ),
+ py_requirement(
+ "requests",
+ direct = False,
+ ),
+ py_requirement(
+ "six",
+ direct = False,
+ ),
+ py_requirement(
+ "urllib3",
+ direct = False,
+ ),
+ py_requirement(
+ "websocket-client",
+ direct = False,
+ ),
],
)
@@ -87,10 +143,16 @@ py_library(
srcs = ["ssh_connection.py"],
deps = [
"//benchmarks/harness",
- py_requirement("bcrypt", False),
- py_requirement("cffi", True),
- py_requirement("paramiko", True),
- py_requirement("cryptography", False),
+ py_requirement(
+ "bcrypt",
+ direct = False,
+ ),
+ py_requirement("cffi"),
+ py_requirement("paramiko"),
+ py_requirement(
+ "cryptography",
+ direct = False,
+ ),
],
)
@@ -98,16 +160,43 @@ py_library(
name = "tunnel_dispatcher",
srcs = ["tunnel_dispatcher.py"],
deps = [
- py_requirement("asn1crypto", False),
- py_requirement("chardet", False),
- py_requirement("certifi", False),
- py_requirement("docker", True),
- py_requirement("docker-pycreds", False),
- py_requirement("idna", False),
- py_requirement("pexpect", True),
- py_requirement("ptyprocess", False),
- py_requirement("requests", False),
- py_requirement("urllib3", False),
- py_requirement("websocket-client", False),
+ py_requirement(
+ "asn1crypto",
+ direct = False,
+ ),
+ py_requirement(
+ "chardet",
+ direct = False,
+ ),
+ py_requirement(
+ "certifi",
+ direct = False,
+ ),
+ py_requirement("docker"),
+ py_requirement(
+ "docker-pycreds",
+ direct = False,
+ ),
+ py_requirement(
+ "idna",
+ direct = False,
+ ),
+ py_requirement("pexpect"),
+ py_requirement(
+ "ptyprocess",
+ direct = False,
+ ),
+ py_requirement(
+ "requests",
+ direct = False,
+ ),
+ py_requirement(
+ "urllib3",
+ direct = False,
+ ),
+ py_requirement(
+ "websocket-client",
+ direct = False,
+ ),
],
)
diff --git a/benchmarks/harness/machine_producers/BUILD b/benchmarks/harness/machine_producers/BUILD
index 3711a397f..81f19bd08 100644
--- a/benchmarks/harness/machine_producers/BUILD
+++ b/benchmarks/harness/machine_producers/BUILD
@@ -31,7 +31,10 @@ py_library(
deps = [
"//benchmarks/harness:machine",
"//benchmarks/harness/machine_producers:machine_producer",
- py_requirement("PyYAML", False),
+ py_requirement(
+ "PyYAML",
+ direct = False,
+ ),
],
)
diff --git a/benchmarks/runner/BUILD b/benchmarks/runner/BUILD
index fae0ca800..471debfdf 100644
--- a/benchmarks/runner/BUILD
+++ b/benchmarks/runner/BUILD
@@ -1,4 +1,5 @@
load("//tools:defs.bzl", "py_library", "py_requirement", "py_test")
+load("//benchmarks:defs.bzl", "test_deps")
package(licenses = ["notice"])
@@ -28,7 +29,7 @@ py_library(
"//benchmarks/suites:startup",
"//benchmarks/suites:sysbench",
"//benchmarks/suites:syscall",
- py_requirement("click", True),
+ py_requirement("click"),
],
)
@@ -36,7 +37,7 @@ py_library(
name = "commands",
srcs = ["commands.py"],
deps = [
- py_requirement("click", True),
+ py_requirement("click"),
],
)
@@ -48,16 +49,8 @@ py_test(
"local",
"manual",
],
- deps = [
+ deps = test_deps + [
":runner",
- py_requirement("click", True),
- py_requirement("attrs", False),
- py_requirement("atomicwrites", False),
- py_requirement("more-itertools", False),
- py_requirement("pathlib2", False),
- py_requirement("pluggy", False),
- py_requirement("py", False),
- py_requirement("pytest", True),
- py_requirement("six", False),
+ py_requirement("click"),
],
)
diff --git a/benchmarks/workloads/ab/BUILD b/benchmarks/workloads/ab/BUILD
index 4dd91ceb3..945ac7026 100644
--- a/benchmarks/workloads/ab/BUILD
+++ b/benchmarks/workloads/ab/BUILD
@@ -1,4 +1,5 @@
-load("//tools:defs.bzl", "pkg_tar", "py_library", "py_requirement", "py_test")
+load("//tools:defs.bzl", "pkg_tar", "py_library", "py_test")
+load("//benchmarks:defs.bzl", "test_deps")
package(
default_visibility = ["//benchmarks:__subpackages__"],
@@ -14,16 +15,8 @@ py_test(
name = "ab_test",
srcs = ["ab_test.py"],
python_version = "PY3",
- deps = [
+ deps = test_deps + [
":ab",
- py_requirement("attrs", False),
- py_requirement("atomicwrites", False),
- py_requirement("more-itertools", False),
- py_requirement("pathlib2", False),
- py_requirement("pluggy", False),
- py_requirement("py", False),
- py_requirement("pytest", True),
- py_requirement("six", False),
],
)
diff --git a/benchmarks/workloads/absl/BUILD b/benchmarks/workloads/absl/BUILD
index 55dae3baa..bb1a308bf 100644
--- a/benchmarks/workloads/absl/BUILD
+++ b/benchmarks/workloads/absl/BUILD
@@ -1,4 +1,5 @@
-load("//tools:defs.bzl", "pkg_tar", "py_library", "py_requirement", "py_test")
+load("//tools:defs.bzl", "pkg_tar", "py_library", "py_test")
+load("//benchmarks:defs.bzl", "test_deps")
package(
default_visibility = ["//benchmarks:__subpackages__"],
@@ -14,16 +15,8 @@ py_test(
name = "absl_test",
srcs = ["absl_test.py"],
python_version = "PY3",
- deps = [
+ deps = test_deps + [
":absl",
- py_requirement("attrs", False),
- py_requirement("atomicwrites", False),
- py_requirement("more-itertools", False),
- py_requirement("pathlib2", False),
- py_requirement("pluggy", False),
- py_requirement("py", False),
- py_requirement("pytest", True),
- py_requirement("six", False),
],
)
diff --git a/benchmarks/workloads/fio/BUILD b/benchmarks/workloads/fio/BUILD
index 7b78e8e75..24d909c53 100644
--- a/benchmarks/workloads/fio/BUILD
+++ b/benchmarks/workloads/fio/BUILD
@@ -1,4 +1,5 @@
-load("//tools:defs.bzl", "pkg_tar", "py_library", "py_requirement", "py_test")
+load("//tools:defs.bzl", "pkg_tar", "py_library", "py_test")
+load("//benchmarks:defs.bzl", "test_deps")
package(
default_visibility = ["//benchmarks:__subpackages__"],
@@ -14,16 +15,8 @@ py_test(
name = "fio_test",
srcs = ["fio_test.py"],
python_version = "PY3",
- deps = [
+ deps = test_deps + [
":fio",
- py_requirement("attrs", False),
- py_requirement("atomicwrites", False),
- py_requirement("more-itertools", False),
- py_requirement("pathlib2", False),
- py_requirement("pluggy", False),
- py_requirement("py", False),
- py_requirement("pytest", True),
- py_requirement("six", False),
],
)
diff --git a/benchmarks/workloads/iperf/BUILD b/benchmarks/workloads/iperf/BUILD
index 570f40148..91b953718 100644
--- a/benchmarks/workloads/iperf/BUILD
+++ b/benchmarks/workloads/iperf/BUILD
@@ -1,4 +1,5 @@
-load("//tools:defs.bzl", "pkg_tar", "py_library", "py_requirement", "py_test")
+load("//tools:defs.bzl", "pkg_tar", "py_library", "py_test")
+load("//benchmarks:defs.bzl", "test_deps")
package(
default_visibility = ["//benchmarks:__subpackages__"],
@@ -14,16 +15,8 @@ py_test(
name = "iperf_test",
srcs = ["iperf_test.py"],
python_version = "PY3",
- deps = [
+ deps = test_deps + [
":iperf",
- py_requirement("attrs", False),
- py_requirement("atomicwrites", False),
- py_requirement("more-itertools", False),
- py_requirement("pathlib2", False),
- py_requirement("pluggy", False),
- py_requirement("py", False),
- py_requirement("pytest", True),
- py_requirement("six", False),
],
)
diff --git a/benchmarks/workloads/redisbenchmark/BUILD b/benchmarks/workloads/redisbenchmark/BUILD
index f472a4443..147cfedd2 100644
--- a/benchmarks/workloads/redisbenchmark/BUILD
+++ b/benchmarks/workloads/redisbenchmark/BUILD
@@ -1,4 +1,5 @@
-load("//tools:defs.bzl", "pkg_tar", "py_library", "py_requirement", "py_test")
+load("//tools:defs.bzl", "pkg_tar", "py_library", "py_test")
+load("//benchmarks:defs.bzl", "test_deps")
package(
default_visibility = ["//benchmarks:__subpackages__"],
@@ -14,16 +15,8 @@ py_test(
name = "redisbenchmark_test",
srcs = ["redisbenchmark_test.py"],
python_version = "PY3",
- deps = [
+ deps = test_deps + [
":redisbenchmark",
- py_requirement("attrs", False),
- py_requirement("atomicwrites", False),
- py_requirement("more-itertools", False),
- py_requirement("pathlib2", False),
- py_requirement("pluggy", False),
- py_requirement("py", False),
- py_requirement("pytest", True),
- py_requirement("six", False),
],
)
diff --git a/benchmarks/workloads/sysbench/BUILD b/benchmarks/workloads/sysbench/BUILD
index 3834af7ed..ab2556064 100644
--- a/benchmarks/workloads/sysbench/BUILD
+++ b/benchmarks/workloads/sysbench/BUILD
@@ -1,4 +1,5 @@
-load("//tools:defs.bzl", "pkg_tar", "py_library", "py_requirement", "py_test")
+load("//tools:defs.bzl", "pkg_tar", "py_library", "py_test")
+load("//benchmarks:defs.bzl", "test_deps")
package(
default_visibility = ["//benchmarks:__subpackages__"],
@@ -14,16 +15,8 @@ py_test(
name = "sysbench_test",
srcs = ["sysbench_test.py"],
python_version = "PY3",
- deps = [
+ deps = test_deps + [
":sysbench",
- py_requirement("attrs", False),
- py_requirement("atomicwrites", False),
- py_requirement("more-itertools", False),
- py_requirement("pathlib2", False),
- py_requirement("pluggy", False),
- py_requirement("py", False),
- py_requirement("pytest", True),
- py_requirement("six", False),
],
)
diff --git a/benchmarks/workloads/syscall/BUILD b/benchmarks/workloads/syscall/BUILD
index dba4bb1e7..f8c43bca1 100644
--- a/benchmarks/workloads/syscall/BUILD
+++ b/benchmarks/workloads/syscall/BUILD
@@ -1,4 +1,5 @@
-load("//tools:defs.bzl", "pkg_tar", "py_library", "py_requirement", "py_test")
+load("//tools:defs.bzl", "pkg_tar", "py_library", "py_test")
+load("//benchmarks:defs.bzl", "test_deps")
package(
default_visibility = ["//benchmarks:__subpackages__"],
@@ -14,16 +15,8 @@ py_test(
name = "syscall_test",
srcs = ["syscall_test.py"],
python_version = "PY3",
- deps = [
+ deps = test_deps + [
":syscall",
- py_requirement("attrs", False),
- py_requirement("atomicwrites", False),
- py_requirement("more-itertools", False),
- py_requirement("pathlib2", False),
- py_requirement("pluggy", False),
- py_requirement("py", False),
- py_requirement("pytest", True),
- py_requirement("six", False),
],
)
diff --git a/kokoro/runtime_tests/runtime_tests.sh b/kokoro/runtime_tests/runtime_tests.sh
index 9ee991e42..73a58f806 100755
--- a/kokoro/runtime_tests/runtime_tests.sh
+++ b/kokoro/runtime_tests/runtime_tests.sh
@@ -14,7 +14,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-source $(dirname $0)/common.sh
+# Run in the root of the repo.
+cd "$(dirname "$0")"
+cd "$(git rev-parse --show-toplevel)"
+
+source scripts/common.sh
if [ ! -v RUNTIME_TEST_NAME ]; then
echo 'Must set $RUNTIME_TEST_NAME' >&2
diff --git a/pkg/abi/linux/socket.go b/pkg/abi/linux/socket.go
index 766ee4014..4a14ef691 100644
--- a/pkg/abi/linux/socket.go
+++ b/pkg/abi/linux/socket.go
@@ -411,6 +411,15 @@ type ControlMessageCredentials struct {
GID uint32
}
+// A ControlMessageIPPacketInfo is IP_PKTINFO socket control message.
+//
+// ControlMessageIPPacketInfo represents struct in_pktinfo from linux/in.h.
+type ControlMessageIPPacketInfo struct {
+ NIC int32
+ LocalAddr InetAddr
+ DestinationAddr InetAddr
+}
+
// SizeOfControlMessageCredentials is the binary size of a
// ControlMessageCredentials struct.
var SizeOfControlMessageCredentials = int(binary.Size(ControlMessageCredentials{}))
@@ -431,6 +440,10 @@ const SizeOfControlMessageTOS = 1
// SizeOfControlMessageTClass is the size of an IPV6_TCLASS control message.
const SizeOfControlMessageTClass = 4
+// SizeOfControlMessageIPPacketInfo is the size of an IP_PKTINFO
+// control message.
+const SizeOfControlMessageIPPacketInfo = 12
+
// SCM_MAX_FD is the maximum number of FDs accepted in a single sendmsg call.
// From net/scm.h.
const SCM_MAX_FD = 253
diff --git a/pkg/binary/binary.go b/pkg/binary/binary.go
index 631785f7b..25065aef9 100644
--- a/pkg/binary/binary.go
+++ b/pkg/binary/binary.go
@@ -254,3 +254,13 @@ func WriteUint64(w io.Writer, order binary.ByteOrder, num uint64) error {
_, err := w.Write(buf)
return err
}
+
+// AlignUp rounds a length up to an alignment. align must be a power of 2.
+func AlignUp(length int, align uint) int {
+ return (length + int(align) - 1) & ^(int(align) - 1)
+}
+
+// AlignDown rounds a length down to an alignment. align must be a power of 2.
+func AlignDown(length int, align uint) int {
+ return length & ^(int(align) - 1)
+}
diff --git a/pkg/sentry/fs/fsutil/host_file_mapper.go b/pkg/sentry/fs/fsutil/host_file_mapper.go
index 67278aa86..e82afd112 100644
--- a/pkg/sentry/fs/fsutil/host_file_mapper.go
+++ b/pkg/sentry/fs/fsutil/host_file_mapper.go
@@ -65,13 +65,18 @@ type mapping struct {
writable bool
}
-// NewHostFileMapper returns a HostFileMapper with no references or cached
-// mappings.
+// Init must be called on zero-value HostFileMappers before first use.
+func (f *HostFileMapper) Init() {
+ f.refs = make(map[uint64]int32)
+ f.mappings = make(map[uint64]mapping)
+}
+
+// NewHostFileMapper returns an initialized HostFileMapper allocated on the
+// heap with no references or cached mappings.
func NewHostFileMapper() *HostFileMapper {
- return &HostFileMapper{
- refs: make(map[uint64]int32),
- mappings: make(map[uint64]mapping),
- }
+ f := &HostFileMapper{}
+ f.Init()
+ return f
}
// IncRefOn increments the reference count on all offsets in mr.
diff --git a/pkg/sentry/fs/tty/slave.go b/pkg/sentry/fs/tty/slave.go
index db55cdc48..6a2dbc576 100644
--- a/pkg/sentry/fs/tty/slave.go
+++ b/pkg/sentry/fs/tty/slave.go
@@ -73,7 +73,7 @@ func (si *slaveInodeOperations) Release(ctx context.Context) {
}
// Truncate implements fs.InodeOperations.Truncate.
-func (slaveInodeOperations) Truncate(context.Context, *fs.Inode, int64) error {
+func (*slaveInodeOperations) Truncate(context.Context, *fs.Inode, int64) error {
return nil
}
diff --git a/pkg/sentry/fsimpl/gofer/regular_file.go b/pkg/sentry/fsimpl/gofer/regular_file.go
index 8e11e06b3..54c1031a7 100644
--- a/pkg/sentry/fsimpl/gofer/regular_file.go
+++ b/pkg/sentry/fsimpl/gofer/regular_file.go
@@ -571,6 +571,8 @@ func (fd *regularFileFD) ConfigureMMap(ctx context.Context, opts *memmap.MMapOpt
default:
panic(fmt.Sprintf("unknown InteropMode %v", d.fs.opts.interop))
}
+ // After this point, d may be used as a memmap.Mappable.
+ d.pf.hostFileMapperInitOnce.Do(d.pf.hostFileMapper.Init)
return vfs.GenericConfigureMMap(&fd.vfsfd, d, opts)
}
@@ -799,6 +801,9 @@ type dentryPlatformFile struct {
// If this dentry represents a regular file, and handle.fd >= 0,
// hostFileMapper caches mappings of handle.fd.
hostFileMapper fsutil.HostFileMapper
+
+ // hostFileMapperInitOnce is used to lazily initialize hostFileMapper.
+ hostFileMapperInitOnce sync.Once
}
// IncRef implements platform.File.IncRef.
diff --git a/pkg/sentry/socket/control/BUILD b/pkg/sentry/socket/control/BUILD
index 79e16d6e8..4d42d29cb 100644
--- a/pkg/sentry/socket/control/BUILD
+++ b/pkg/sentry/socket/control/BUILD
@@ -19,6 +19,7 @@ go_library(
"//pkg/sentry/socket",
"//pkg/sentry/socket/unix/transport",
"//pkg/syserror",
+ "//pkg/tcpip",
"//pkg/usermem",
],
)
diff --git a/pkg/sentry/socket/control/control.go b/pkg/sentry/socket/control/control.go
index 00265f15b..4667373d2 100644
--- a/pkg/sentry/socket/control/control.go
+++ b/pkg/sentry/socket/control/control.go
@@ -26,6 +26,7 @@ import (
"gvisor.dev/gvisor/pkg/sentry/socket"
"gvisor.dev/gvisor/pkg/sentry/socket/unix/transport"
"gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/usermem"
)
@@ -189,7 +190,7 @@ func putUint32(buf []byte, n uint32) []byte {
// putCmsg writes a control message header and as much data as will fit into
// the unused capacity of a buffer.
func putCmsg(buf []byte, flags int, msgType uint32, align uint, data []int32) ([]byte, int) {
- space := AlignDown(cap(buf)-len(buf), 4)
+ space := binary.AlignDown(cap(buf)-len(buf), 4)
// We can't write to space that doesn't exist, so if we are going to align
// the available space, we must align down.
@@ -282,19 +283,9 @@ func PackCredentials(t *kernel.Task, creds SCMCredentials, buf []byte, flags int
return putCmsg(buf, flags, linux.SCM_CREDENTIALS, align, c)
}
-// AlignUp rounds a length up to an alignment. align must be a power of 2.
-func AlignUp(length int, align uint) int {
- return (length + int(align) - 1) & ^(int(align) - 1)
-}
-
-// AlignDown rounds a down to an alignment. align must be a power of 2.
-func AlignDown(length int, align uint) int {
- return length & ^(int(align) - 1)
-}
-
// alignSlice extends a slice's length (up to the capacity) to align it.
func alignSlice(buf []byte, align uint) []byte {
- aligned := AlignUp(len(buf), align)
+ aligned := binary.AlignUp(len(buf), align)
if aligned > cap(buf) {
// Linux allows unaligned data if there isn't room for alignment.
// Since there isn't room for alignment, there isn't room for any
@@ -348,6 +339,22 @@ func PackTClass(t *kernel.Task, tClass int32, buf []byte) []byte {
)
}
+// PackIPPacketInfo packs an IP_PKTINFO socket control message.
+func PackIPPacketInfo(t *kernel.Task, packetInfo tcpip.IPPacketInfo, buf []byte) []byte {
+ var p linux.ControlMessageIPPacketInfo
+ p.NIC = int32(packetInfo.NIC)
+ copy(p.LocalAddr[:], []byte(packetInfo.LocalAddr))
+ copy(p.DestinationAddr[:], []byte(packetInfo.DestinationAddr))
+
+ return putCmsgStruct(
+ buf,
+ linux.SOL_IP,
+ linux.IP_PKTINFO,
+ t.Arch().Width(),
+ p,
+ )
+}
+
// PackControlMessages packs control messages into the given buffer.
//
// We skip control messages specific to Unix domain sockets.
@@ -372,12 +379,16 @@ func PackControlMessages(t *kernel.Task, cmsgs socket.ControlMessages, buf []byt
buf = PackTClass(t, cmsgs.IP.TClass, buf)
}
+ if cmsgs.IP.HasIPPacketInfo {
+ buf = PackIPPacketInfo(t, cmsgs.IP.PacketInfo, buf)
+ }
+
return buf
}
// cmsgSpace is equivalent to CMSG_SPACE in Linux.
func cmsgSpace(t *kernel.Task, dataLen int) int {
- return linux.SizeOfControlMessageHeader + AlignUp(dataLen, t.Arch().Width())
+ return linux.SizeOfControlMessageHeader + binary.AlignUp(dataLen, t.Arch().Width())
}
// CmsgsSpace returns the number of bytes needed to fit the control messages
@@ -404,6 +415,16 @@ func CmsgsSpace(t *kernel.Task, cmsgs socket.ControlMessages) int {
return space
}
+// NewIPPacketInfo returns the IPPacketInfo struct.
+func NewIPPacketInfo(packetInfo linux.ControlMessageIPPacketInfo) tcpip.IPPacketInfo {
+ var p tcpip.IPPacketInfo
+ p.NIC = tcpip.NICID(packetInfo.NIC)
+ copy([]byte(p.LocalAddr), packetInfo.LocalAddr[:])
+ copy([]byte(p.DestinationAddr), packetInfo.DestinationAddr[:])
+
+ return p
+}
+
// Parse parses a raw socket control message into portable objects.
func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte) (socket.ControlMessages, error) {
var (
@@ -437,7 +458,7 @@ func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte) (socket.Con
case linux.SOL_SOCKET:
switch h.Type {
case linux.SCM_RIGHTS:
- rightsSize := AlignDown(length, linux.SizeOfControlMessageRight)
+ rightsSize := binary.AlignDown(length, linux.SizeOfControlMessageRight)
numRights := rightsSize / linux.SizeOfControlMessageRight
if len(fds)+numRights > linux.SCM_MAX_FD {
@@ -448,7 +469,7 @@ func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte) (socket.Con
fds = append(fds, int32(usermem.ByteOrder.Uint32(buf[j:j+linux.SizeOfControlMessageRight])))
}
- i += AlignUp(length, width)
+ i += binary.AlignUp(length, width)
case linux.SCM_CREDENTIALS:
if length < linux.SizeOfControlMessageCredentials {
@@ -462,7 +483,7 @@ func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte) (socket.Con
return socket.ControlMessages{}, err
}
cmsgs.Unix.Credentials = scmCreds
- i += AlignUp(length, width)
+ i += binary.AlignUp(length, width)
default:
// Unknown message type.
@@ -476,7 +497,19 @@ func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte) (socket.Con
}
cmsgs.IP.HasTOS = true
binary.Unmarshal(buf[i:i+linux.SizeOfControlMessageTOS], usermem.ByteOrder, &cmsgs.IP.TOS)
- i += AlignUp(length, width)
+ i += binary.AlignUp(length, width)
+
+ case linux.IP_PKTINFO:
+ if length < linux.SizeOfControlMessageIPPacketInfo {
+ return socket.ControlMessages{}, syserror.EINVAL
+ }
+
+ cmsgs.IP.HasIPPacketInfo = true
+ var packetInfo linux.ControlMessageIPPacketInfo
+ binary.Unmarshal(buf[i:i+linux.SizeOfControlMessageIPPacketInfo], usermem.ByteOrder, &packetInfo)
+
+ cmsgs.IP.PacketInfo = NewIPPacketInfo(packetInfo)
+ i += binary.AlignUp(length, width)
default:
return socket.ControlMessages{}, syserror.EINVAL
@@ -489,7 +522,7 @@ func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte) (socket.Con
}
cmsgs.IP.HasTClass = true
binary.Unmarshal(buf[i:i+linux.SizeOfControlMessageTClass], usermem.ByteOrder, &cmsgs.IP.TClass)
- i += AlignUp(length, width)
+ i += binary.AlignUp(length, width)
default:
return socket.ControlMessages{}, syserror.EINVAL
diff --git a/pkg/sentry/socket/hostinet/socket.go b/pkg/sentry/socket/hostinet/socket.go
index de76388ac..22f78d2e2 100644
--- a/pkg/sentry/socket/hostinet/socket.go
+++ b/pkg/sentry/socket/hostinet/socket.go
@@ -289,7 +289,7 @@ func (s *socketOperations) GetSockOpt(t *kernel.Task, level int, name int, outPt
switch level {
case linux.SOL_IP:
switch name {
- case linux.IP_TOS, linux.IP_RECVTOS:
+ case linux.IP_TOS, linux.IP_RECVTOS, linux.IP_PKTINFO:
optlen = sizeofInt32
}
case linux.SOL_IPV6:
@@ -336,6 +336,8 @@ func (s *socketOperations) SetSockOpt(t *kernel.Task, level int, name int, opt [
switch name {
case linux.IP_TOS, linux.IP_RECVTOS:
optlen = sizeofInt32
+ case linux.IP_PKTINFO:
+ optlen = linux.SizeOfControlMessageIPPacketInfo
}
case linux.SOL_IPV6:
switch name {
@@ -473,7 +475,14 @@ func (s *socketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags
case syscall.IP_TOS:
controlMessages.IP.HasTOS = true
binary.Unmarshal(unixCmsg.Data[:linux.SizeOfControlMessageTOS], usermem.ByteOrder, &controlMessages.IP.TOS)
+
+ case syscall.IP_PKTINFO:
+ controlMessages.IP.HasIPPacketInfo = true
+ var packetInfo linux.ControlMessageIPPacketInfo
+ binary.Unmarshal(unixCmsg.Data[:linux.SizeOfControlMessageIPPacketInfo], usermem.ByteOrder, &packetInfo)
+ controlMessages.IP.PacketInfo = control.NewIPPacketInfo(packetInfo)
}
+
case syscall.SOL_IPV6:
switch unixCmsg.Header.Type {
case syscall.IPV6_TCLASS:
diff --git a/pkg/sentry/socket/netfilter/extensions.go b/pkg/sentry/socket/netfilter/extensions.go
index 22fd0ebe7..b4b244abf 100644
--- a/pkg/sentry/socket/netfilter/extensions.go
+++ b/pkg/sentry/socket/netfilter/extensions.go
@@ -72,7 +72,7 @@ func marshalEntryMatch(name string, data []byte) []byte {
nflog("marshaling matcher %q", name)
// We have to pad this struct size to a multiple of 8 bytes.
- size := alignUp(linux.SizeOfXTEntryMatch+len(data), 8)
+ size := binary.AlignUp(linux.SizeOfXTEntryMatch+len(data), 8)
matcher := linux.KernelXTEntryMatch{
XTEntryMatch: linux.XTEntryMatch{
MatchSize: uint16(size),
@@ -93,8 +93,3 @@ func unmarshalMatcher(match linux.XTEntryMatch, filter iptables.IPHeaderFilter,
}
return matchMaker.unmarshal(buf, filter)
}
-
-// alignUp rounds a length up to an alignment. align must be a power of 2.
-func alignUp(length int, align uint) int {
- return (length + int(align) - 1) & ^(int(align) - 1)
-}
diff --git a/pkg/sentry/socket/netlink/message.go b/pkg/sentry/socket/netlink/message.go
index 4ea252ccb..0899c61d1 100644
--- a/pkg/sentry/socket/netlink/message.go
+++ b/pkg/sentry/socket/netlink/message.go
@@ -23,18 +23,11 @@ import (
"gvisor.dev/gvisor/pkg/usermem"
)
-// alignUp rounds a length up to an alignment.
-//
-// Preconditions: align is a power of two.
-func alignUp(length int, align uint) int {
- return (length + int(align) - 1) &^ (int(align) - 1)
-}
-
// alignPad returns the length of padding required for alignment.
//
// Preconditions: align is a power of two.
func alignPad(length int, align uint) int {
- return alignUp(length, align) - length
+ return binary.AlignUp(length, align) - length
}
// Message contains a complete serialized netlink message.
@@ -138,7 +131,7 @@ func (m *Message) Finalize() []byte {
// Align the message. Note that the message length in the header (set
// above) is the useful length of the message, not the total aligned
// length. See net/netlink/af_netlink.c:__nlmsg_put.
- aligned := alignUp(len(m.buf), linux.NLMSG_ALIGNTO)
+ aligned := binary.AlignUp(len(m.buf), linux.NLMSG_ALIGNTO)
m.putZeros(aligned - len(m.buf))
return m.buf
}
@@ -173,7 +166,7 @@ func (m *Message) PutAttr(atype uint16, v interface{}) {
m.Put(v)
// Align the attribute.
- aligned := alignUp(l, linux.NLA_ALIGNTO)
+ aligned := binary.AlignUp(l, linux.NLA_ALIGNTO)
m.putZeros(aligned - l)
}
@@ -190,7 +183,7 @@ func (m *Message) PutAttrString(atype uint16, s string) {
m.putZeros(1)
// Align the attribute.
- aligned := alignUp(l, linux.NLA_ALIGNTO)
+ aligned := binary.AlignUp(l, linux.NLA_ALIGNTO)
m.putZeros(aligned - l)
}
diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go
index ed2fbcceb..9757fbfba 100644
--- a/pkg/sentry/socket/netstack/netstack.go
+++ b/pkg/sentry/socket/netstack/netstack.go
@@ -1414,6 +1414,21 @@ func getSockOptIP(t *kernel.Task, ep commonEndpoint, name, outLen int, family in
}
return o, nil
+ case linux.IP_PKTINFO:
+ if outLen < sizeOfInt32 {
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ v, err := ep.GetSockOptBool(tcpip.ReceiveIPPacketInfoOption)
+ if err != nil {
+ return nil, syserr.TranslateNetstackError(err)
+ }
+ var o int32
+ if v {
+ o = 1
+ }
+ return o, nil
+
default:
emitUnimplementedEventIP(t, name)
}
@@ -1762,6 +1777,7 @@ func setSockOptIPv6(t *kernel.Task, ep commonEndpoint, name int, optVal []byte)
linux.IPV6_IPSEC_POLICY,
linux.IPV6_JOIN_ANYCAST,
linux.IPV6_LEAVE_ANYCAST,
+ // TODO(b/148887420): Add support for IPV6_PKTINFO.
linux.IPV6_PKTINFO,
linux.IPV6_ROUTER_ALERT,
linux.IPV6_XFRM_POLICY,
@@ -1949,6 +1965,16 @@ func setSockOptIP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) *s
}
return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.ReceiveTOSOption, v != 0))
+ case linux.IP_PKTINFO:
+ if len(optVal) == 0 {
+ return nil
+ }
+ v, err := parseIntOrChar(optVal)
+ if err != nil {
+ return err
+ }
+ return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.ReceiveIPPacketInfoOption, v != 0))
+
case linux.IP_ADD_SOURCE_MEMBERSHIP,
linux.IP_BIND_ADDRESS_NO_PORT,
linux.IP_BLOCK_SOURCE,
@@ -1964,7 +1990,6 @@ func setSockOptIP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) *s
linux.IP_NODEFRAG,
linux.IP_OPTIONS,
linux.IP_PASSSEC,
- linux.IP_PKTINFO,
linux.IP_RECVERR,
linux.IP_RECVFRAGSIZE,
linux.IP_RECVOPTS,
@@ -2395,10 +2420,12 @@ func (s *SocketOperations) nonBlockingRead(ctx context.Context, dst usermem.IOSe
func (s *SocketOperations) controlMessages() socket.ControlMessages {
return socket.ControlMessages{
IP: tcpip.ControlMessages{
- HasTimestamp: s.readCM.HasTimestamp && s.sockOptTimestamp,
- Timestamp: s.readCM.Timestamp,
- HasTOS: s.readCM.HasTOS,
- TOS: s.readCM.TOS,
+ HasTimestamp: s.readCM.HasTimestamp && s.sockOptTimestamp,
+ Timestamp: s.readCM.Timestamp,
+ HasTOS: s.readCM.HasTOS,
+ TOS: s.readCM.TOS,
+ HasIPPacketInfo: s.readCM.HasIPPacketInfo,
+ PacketInfo: s.readCM.PacketInfo,
},
}
}
diff --git a/pkg/sentry/strace/BUILD b/pkg/sentry/strace/BUILD
index 762a946fe..2f39a6f2b 100644
--- a/pkg/sentry/strace/BUILD
+++ b/pkg/sentry/strace/BUILD
@@ -30,7 +30,6 @@ go_library(
"//pkg/seccomp",
"//pkg/sentry/arch",
"//pkg/sentry/kernel",
- "//pkg/sentry/socket/control",
"//pkg/sentry/socket/netlink",
"//pkg/sentry/socket/netstack",
"//pkg/sentry/syscalls/linux",
diff --git a/pkg/sentry/strace/socket.go b/pkg/sentry/strace/socket.go
index f7ff4573e..51e6d81b2 100644
--- a/pkg/sentry/strace/socket.go
+++ b/pkg/sentry/strace/socket.go
@@ -22,7 +22,6 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/binary"
"gvisor.dev/gvisor/pkg/sentry/kernel"
- "gvisor.dev/gvisor/pkg/sentry/socket/control"
"gvisor.dev/gvisor/pkg/sentry/socket/netlink"
"gvisor.dev/gvisor/pkg/sentry/socket/netstack"
slinux "gvisor.dev/gvisor/pkg/sentry/syscalls/linux"
@@ -220,13 +219,13 @@ func cmsghdr(t *kernel.Task, addr usermem.Addr, length uint64, maxBytes uint64)
if skipData {
strs = append(strs, fmt.Sprintf("{level=%s, type=%s, length=%d}", level, typ, h.Length))
- i += control.AlignUp(length, width)
+ i += binary.AlignUp(length, width)
continue
}
switch h.Type {
case linux.SCM_RIGHTS:
- rightsSize := control.AlignDown(length, linux.SizeOfControlMessageRight)
+ rightsSize := binary.AlignDown(length, linux.SizeOfControlMessageRight)
numRights := rightsSize / linux.SizeOfControlMessageRight
fds := make(linux.ControlMessageRights, numRights)
@@ -295,7 +294,7 @@ func cmsghdr(t *kernel.Task, addr usermem.Addr, length uint64, maxBytes uint64)
default:
panic("unreachable")
}
- i += control.AlignUp(length, width)
+ i += binary.AlignUp(length, width)
}
return fmt.Sprintf("%#x %s", addr, strings.Join(strs, ", "))
diff --git a/pkg/sleep/commit_noasm.go b/pkg/sleep/commit_noasm.go
index 3af447fb9..f59061f37 100644
--- a/pkg/sleep/commit_noasm.go
+++ b/pkg/sleep/commit_noasm.go
@@ -28,15 +28,6 @@ import "sync/atomic"
// It is written in assembly because it is called from g0, so it doesn't have
// a race context.
func commitSleep(g uintptr, waitingG *uintptr) bool {
- for {
- // Check if the wait was aborted.
- if atomic.LoadUintptr(waitingG) == 0 {
- return false
- }
-
- // Try to store the G so that wakers know who to wake.
- if atomic.CompareAndSwapUintptr(waitingG, preparingG, g) {
- return true
- }
- }
+ // Try to store the G so that wakers know who to wake.
+ return atomic.CompareAndSwapUintptr(waitingG, preparingG, g)
}
diff --git a/pkg/sleep/sleep_unsafe.go b/pkg/sleep/sleep_unsafe.go
index acbf0229b..65bfcf778 100644
--- a/pkg/sleep/sleep_unsafe.go
+++ b/pkg/sleep/sleep_unsafe.go
@@ -299,20 +299,17 @@ func (s *Sleeper) enqueueAssertedWaker(w *Waker) {
}
}
- for {
- // Nothing to do if there isn't a G waiting.
- g := atomic.LoadUintptr(&s.waitingG)
- if g == 0 {
- return
- }
+ // Nothing to do if there isn't a G waiting.
+ if atomic.LoadUintptr(&s.waitingG) == 0 {
+ return
+ }
- // Signal to the sleeper that a waker has been asserted.
- if atomic.CompareAndSwapUintptr(&s.waitingG, g, 0) {
- if g != preparingG {
- // We managed to get a G. Wake it up.
- goready(g, 0)
- }
- }
+ // Signal to the sleeper that a waker has been asserted.
+ switch g := atomic.SwapUintptr(&s.waitingG, 0); g {
+ case 0, preparingG:
+ default:
+ // We managed to get a G. Wake it up.
+ goready(g, 0)
}
}
diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go
index 78d451cca..ca3a7a07e 100644
--- a/pkg/tcpip/stack/nic.go
+++ b/pkg/tcpip/stack/nic.go
@@ -1215,6 +1215,11 @@ func (n *NIC) ID() tcpip.NICID {
return n.id
}
+// Name returns the name of n.
+func (n *NIC) Name() string {
+ return n.name
+}
+
// Stack returns the instance of the Stack that owns this NIC.
func (n *NIC) Stack() *Stack {
return n.stack
diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go
index b793f1d74..6eac16e16 100644
--- a/pkg/tcpip/stack/stack.go
+++ b/pkg/tcpip/stack/stack.go
@@ -890,6 +890,15 @@ func (s *Stack) CreateNICWithOptions(id tcpip.NICID, ep LinkEndpoint, opts NICOp
return tcpip.ErrDuplicateNICID
}
+ // Make sure name is unique, unless unnamed.
+ if opts.Name != "" {
+ for _, n := range s.nics {
+ if n.Name() == opts.Name {
+ return tcpip.ErrDuplicateNICID
+ }
+ }
+ }
+
n := newNIC(s, id, opts.Name, ep, opts.Context)
s.nics[id] = n
diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go
index 24133e6f2..7ba604442 100644
--- a/pkg/tcpip/stack/stack_test.go
+++ b/pkg/tcpip/stack/stack_test.go
@@ -1792,6 +1792,91 @@ func TestAddProtocolAddressWithOptions(t *testing.T) {
verifyAddresses(t, expectedAddresses, gotAddresses)
}
+func TestCreateNICWithOptions(t *testing.T) {
+ type callArgsAndExpect struct {
+ nicID tcpip.NICID
+ opts stack.NICOptions
+ err *tcpip.Error
+ }
+
+ tests := []struct {
+ desc string
+ calls []callArgsAndExpect
+ }{
+ {
+ desc: "DuplicateNICID",
+ calls: []callArgsAndExpect{
+ {
+ nicID: tcpip.NICID(1),
+ opts: stack.NICOptions{Name: "eth1"},
+ err: nil,
+ },
+ {
+ nicID: tcpip.NICID(1),
+ opts: stack.NICOptions{Name: "eth2"},
+ err: tcpip.ErrDuplicateNICID,
+ },
+ },
+ },
+ {
+ desc: "DuplicateName",
+ calls: []callArgsAndExpect{
+ {
+ nicID: tcpip.NICID(1),
+ opts: stack.NICOptions{Name: "lo"},
+ err: nil,
+ },
+ {
+ nicID: tcpip.NICID(2),
+ opts: stack.NICOptions{Name: "lo"},
+ err: tcpip.ErrDuplicateNICID,
+ },
+ },
+ },
+ {
+ desc: "Unnamed",
+ calls: []callArgsAndExpect{
+ {
+ nicID: tcpip.NICID(1),
+ opts: stack.NICOptions{},
+ err: nil,
+ },
+ {
+ nicID: tcpip.NICID(2),
+ opts: stack.NICOptions{},
+ err: nil,
+ },
+ },
+ },
+ {
+ desc: "UnnamedDuplicateNICID",
+ calls: []callArgsAndExpect{
+ {
+ nicID: tcpip.NICID(1),
+ opts: stack.NICOptions{},
+ err: nil,
+ },
+ {
+ nicID: tcpip.NICID(1),
+ opts: stack.NICOptions{},
+ err: tcpip.ErrDuplicateNICID,
+ },
+ },
+ },
+ }
+ for _, test := range tests {
+ t.Run(test.desc, func(t *testing.T) {
+ s := stack.New(stack.Options{})
+ ep := channel.New(0, 0, tcpip.LinkAddress("\x00\x00\x00\x00\x00\x00"))
+ for _, call := range test.calls {
+ if got, want := s.CreateNICWithOptions(call.nicID, ep, call.opts), call.err; got != want {
+ t.Fatalf("CreateNICWithOptions(%v, _, %+v) = %v, want %v", call.nicID, call.opts, got, want)
+ }
+ }
+ })
+ }
+}
+
func TestNICStats(t *testing.T) {
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()},
diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go
index 0e944712f..9ca39ce40 100644
--- a/pkg/tcpip/tcpip.go
+++ b/pkg/tcpip/tcpip.go
@@ -328,6 +328,12 @@ type ControlMessages struct {
// Tclass is the IPv6 traffic class of the associated packet.
TClass int32
+
+ // HasIPPacketInfo indicates whether PacketInfo is set.
+ HasIPPacketInfo bool
+
+ // PacketInfo holds interface and address data on an incoming packet.
+ PacketInfo IPPacketInfo
}
// Endpoint is the interface implemented by transport protocols (e.g., tcp, udp)
@@ -503,6 +509,11 @@ const (
// V6OnlyOption is used by {G,S}etSockOptBool to specify whether an IPv6
// socket is to be restricted to sending and receiving IPv6 packets only.
V6OnlyOption
+
+ // ReceiveIPPacketInfoOption is used by {G,S}etSockOptBool to specify
+ // if more inforamtion is provided with incoming packets such
+ // as interface index and address.
+ ReceiveIPPacketInfoOption
)
// SockOptInt represents socket options which values have the int type.
@@ -685,6 +696,20 @@ type IPv4TOSOption uint8
// for all subsequent outgoing IPv6 packets from the endpoint.
type IPv6TrafficClassOption uint8
+// IPPacketInfo is the message struture for IP_PKTINFO.
+//
+// +stateify savable
+type IPPacketInfo struct {
+ // NIC is the ID of the NIC to be used.
+ NIC NICID
+
+ // LocalAddr is the local address.
+ LocalAddr Address
+
+ // DestinationAddr is the destination address.
+ DestinationAddr Address
+}
+
// Route is a row in the routing table. It specifies through which NIC (and
// gateway) sets of packets should be routed. A row is considered viable if the
// masked target address matches the destination address in the row.
diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go
index c9cbed8f4..3fe91cac2 100644
--- a/pkg/tcpip/transport/udp/endpoint.go
+++ b/pkg/tcpip/transport/udp/endpoint.go
@@ -29,6 +29,7 @@ import (
type udpPacket struct {
udpPacketEntry
senderAddress tcpip.FullAddress
+ packetInfo tcpip.IPPacketInfo
data buffer.VectorisedView `state:".(buffer.VectorisedView)"`
timestamp int64
tos uint8
@@ -118,6 +119,9 @@ type endpoint struct {
// as ancillary data to ControlMessages on Read.
receiveTOS bool
+ // receiveIPPacketInfo determines if the packet info is returned by Read.
+ receiveIPPacketInfo bool
+
// shutdownFlags represent the current shutdown state of the endpoint.
shutdownFlags tcpip.ShutdownFlags
@@ -254,11 +258,17 @@ func (e *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMess
}
e.mu.RLock()
receiveTOS := e.receiveTOS
+ receiveIPPacketInfo := e.receiveIPPacketInfo
e.mu.RUnlock()
if receiveTOS {
cm.HasTOS = true
cm.TOS = p.tos
}
+
+ if receiveIPPacketInfo {
+ cm.HasIPPacketInfo = true
+ cm.PacketInfo = p.packetInfo
+ }
return p.data.ToView(), cm, nil
}
@@ -495,6 +505,13 @@ func (e *endpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error {
}
e.v6only = v
+ return nil
+
+ case tcpip.ReceiveIPPacketInfoOption:
+ e.mu.Lock()
+ e.receiveIPPacketInfo = v
+ e.mu.Unlock()
+ return nil
}
return nil
@@ -703,6 +720,12 @@ func (e *endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) {
e.mu.RUnlock()
return v, nil
+
+ case tcpip.ReceiveIPPacketInfoOption:
+ e.mu.RLock()
+ v := e.receiveIPPacketInfo
+ e.mu.RUnlock()
+ return v, nil
}
return false, tcpip.ErrUnknownProtocolOption
@@ -1247,6 +1270,9 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pk
switch r.NetProto {
case header.IPv4ProtocolNumber:
packet.tos, _ = header.IPv4(pkt.NetworkHeader).TOS()
+ packet.packetInfo.LocalAddr = r.LocalAddress
+ packet.packetInfo.DestinationAddr = r.RemoteAddress
+ packet.packetInfo.NIC = r.NICID()
}
packet.timestamp = e.stack.NowNanoseconds()
diff --git a/runsc/sandbox/network.go b/runsc/sandbox/network.go
index ff48f5646..99e143696 100644
--- a/runsc/sandbox/network.go
+++ b/runsc/sandbox/network.go
@@ -174,13 +174,13 @@ func createInterfacesAndRoutesFromNS(conn *urpc.Client, nsPath string, hardwareG
return fmt.Errorf("fetching interface addresses for %q: %v", iface.Name, err)
}
- // We build our own loopback devices.
+ // We build our own loopback device.
if iface.Flags&net.FlagLoopback != 0 {
- links, err := loopbackLinks(iface, allAddrs)
+ link, err := loopbackLink(iface, allAddrs)
if err != nil {
- return fmt.Errorf("getting loopback routes and links for iface %q: %v", iface.Name, err)
+ return fmt.Errorf("getting loopback link for iface %q: %v", iface.Name, err)
}
- args.LoopbackLinks = append(args.LoopbackLinks, links...)
+ args.LoopbackLinks = append(args.LoopbackLinks, link)
continue
}
@@ -339,25 +339,25 @@ func createSocket(iface net.Interface, ifaceLink netlink.Link, enableGSO bool) (
return &socketEntry{deviceFile, gsoMaxSize}, nil
}
-// loopbackLinks collects the links for a loopback interface.
-func loopbackLinks(iface net.Interface, addrs []net.Addr) ([]boot.LoopbackLink, error) {
- var links []boot.LoopbackLink
+// loopbackLink returns the link with addresses and routes for a loopback
+// interface.
+func loopbackLink(iface net.Interface, addrs []net.Addr) (boot.LoopbackLink, error) {
+ link := boot.LoopbackLink{
+ Name: iface.Name,
+ }
for _, addr := range addrs {
ipNet, ok := addr.(*net.IPNet)
if !ok {
- return nil, fmt.Errorf("address is not IPNet: %+v", addr)
+ return boot.LoopbackLink{}, fmt.Errorf("address is not IPNet: %+v", addr)
}
dst := *ipNet
dst.IP = dst.IP.Mask(dst.Mask)
- links = append(links, boot.LoopbackLink{
- Name: iface.Name,
- Addresses: []net.IP{ipNet.IP},
- Routes: []boot.Route{{
- Destination: dst,
- }},
+ link.Addresses = append(link.Addresses, ipNet.IP)
+ link.Routes = append(link.Routes, boot.Route{
+ Destination: dst,
})
}
- return links, nil
+ return link, nil
}
// routesForIface iterates over all routes for the given interface and converts
diff --git a/scripts/common.sh b/scripts/common.sh
index cd91b9f8e..3ca699e4a 100755
--- a/scripts/common.sh
+++ b/scripts/common.sh
@@ -16,7 +16,17 @@
set -xeou pipefail
-source $(dirname $0)/common_build.sh
+# Get the path to the directory this script lives in.
+# If this script is being called with `source`, $0 will be the path of the
+# *sourcing* script, so we can't use `dirname $0` to find scripts in this
+# directory.
+if [[ -v BASH_SOURCE && "$0" != "$BASH_SOURCE" ]]; then
+ declare -r script_dir="$(dirname "$BASH_SOURCE")"
+else
+ declare -r script_dir="$(dirname "$0")"
+fi
+
+source "${script_dir}/common_build.sh"
# Ensure it attempts to collect logs in all cases.
trap collect_logs EXIT
diff --git a/scripts/common_build.sh b/scripts/common_build.sh
index 2c2a826c7..ae8b67383 100755
--- a/scripts/common_build.sh
+++ b/scripts/common_build.sh
@@ -14,8 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-# Install the latest version of Bazel and log the version.
-(which use_bazel.sh && use_bazel.sh latest) || which bazel
+which bazel
bazel version
# Switch into the workspace; only necessary if run with kokoro.
@@ -26,27 +25,30 @@ elif [[ -v KOKORO_GIT_COMMIT ]] && [[ -d github/repo ]]; then
fi
# Set the standard bazel flags.
-declare -r BAZEL_FLAGS=(
+declare -a BAZEL_FLAGS=(
"--show_timestamps"
"--test_output=errors"
"--keep_going"
"--verbose_failures=true"
)
-BAZEL_RBE_AUTH_FLAGS=""
-BAZEL_RBE_FLAGS=""
if [[ -v KOKORO_BAZEL_AUTH_CREDENTIAL ]]; then
- declare -r BAZEL_RBE_AUTH_FLAGS="--auth_credentials=${KOKORO_BAZEL_AUTH_CREDENTIAL}"
- declare -r BAZEL_RBE_FLAGS="--config=remote"
+ BAZEL_FLAGS+=(
+ "--auth_credentials=${KOKORO_BAZEL_AUTH_CREDENTIAL}"
+ "--config=remote"
+ )
fi
+declare -r BAZEL_FLAGS
# Wrap bazel.
function build() {
- bazel build "${BAZEL_RBE_FLAGS}" "${BAZEL_RBE_AUTH_FLAGS}" "${BAZEL_FLAGS[@]}" "$@" 2>&1 |
- tee /dev/fd/2 | grep -E '^ bazel-bin/' | awk '{ print $1; }'
+ bazel build "${BAZEL_FLAGS[@]}" "$@" 2>&1 \
+ | tee /dev/fd/2 \
+ | grep -E '^ bazel-bin/' \
+ | awk '{ print $1; }'
}
function test() {
- bazel test "${BAZEL_RBE_FLAGS}" "${BAZEL_RBE_AUTH_FLAGS}" "${BAZEL_FLAGS[@]}" "$@"
+ bazel test "${BAZEL_FLAGS[@]}" "$@"
}
function run() {
@@ -95,5 +97,8 @@ function collect_logs() {
}
function find_branch_name() {
- git branch --show-current || git rev-parse HEAD || bazel info workspace | xargs basename
+ git branch --show-current \
+ || git rev-parse HEAD \
+ || bazel info workspace \
+ | xargs basename
}
diff --git a/test/runtimes/README.md b/test/runtimes/README.md
index e41e78f77..42d722553 100644
--- a/test/runtimes/README.md
+++ b/test/runtimes/README.md
@@ -12,24 +12,39 @@ The following runtimes are currently supported:
- PHP 7.3
- Python 3.7
-#### Prerequisites:
+### Building and pushing the images:
-1) [Install and configure Docker](https://docs.docker.com/install/)
-
-2) Build each Docker container from the runtimes/images directory:
+The canonical source of images is the
+[gvisor-presubmit container registry](https://gcr.io/gvisor-presubmit/). You can
+build new images with the following command:
```bash
$ cd images
$ docker build -f Dockerfile_$LANG [-t $NAME] .
```
-### Testing:
+To push them to our container registry, set the tag in the command above to
+`gcr.io/gvisor-presubmit/$LANG`, then push them. (Note that you will need
+appropriate permissions to the `gvisor-presubmit` GCP project.)
+
+```bash
+gcloud docker -- push gcr.io/gvisor-presubmit/$LANG
+```
+
+#### Running in Docker locally:
+
+1) [Install and configure Docker](https://docs.docker.com/install/)
+
+2) Pull the image you want to run:
+
+```bash
+$ docker pull gcr.io/gvisor-presubmit/$LANG
+```
-If the prerequisites have been fulfilled, you can run the tests with the
-following command:
+3) Run docker with the image.
```bash
-$ docker run --rm -it $NAME [FLAG]
+$ docker run [--runtime=runsc] --rm -it $NAME [FLAG]
```
Running the command with no flags will cause all the available tests to execute.
diff --git a/test/syscalls/linux/BUILD b/test/syscalls/linux/BUILD
index ca1af209a..e7c82adfc 100644
--- a/test/syscalls/linux/BUILD
+++ b/test/syscalls/linux/BUILD
@@ -10,13 +10,16 @@ exports_files(
"socket.cc",
"socket_inet_loopback.cc",
"socket_ip_loopback_blocking.cc",
+ "socket_ip_tcp_generic_loopback.cc",
"socket_ip_tcp_loopback.cc",
+ "socket_ip_tcp_udp_generic.cc",
"socket_ip_udp_loopback.cc",
"socket_ip_unbound.cc",
"socket_ipv4_tcp_unbound_external_networking_test.cc",
"socket_ipv4_udp_unbound_external_networking_test.cc",
"socket_ipv4_udp_unbound_loopback.cc",
"tcp_socket.cc",
+ "udp_bind.cc",
"udp_socket.cc",
],
visibility = ["//:sandbox"],
diff --git a/test/syscalls/linux/socket_ip_udp_generic.cc b/test/syscalls/linux/socket_ip_udp_generic.cc
index 53290bed7..db5663ecd 100644
--- a/test/syscalls/linux/socket_ip_udp_generic.cc
+++ b/test/syscalls/linux/socket_ip_udp_generic.cc
@@ -357,5 +357,49 @@ TEST_P(UDPSocketPairTest, SetReuseAddrReusePort) {
EXPECT_EQ(get, kSockOptOn);
}
+// Test getsockopt for a socket which is not set with IP_PKTINFO option.
+TEST_P(UDPSocketPairTest, IPPKTINFODefault) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ int get = -1;
+ socklen_t get_len = sizeof(get);
+
+ ASSERT_THAT(
+ getsockopt(sockets->first_fd(), SOL_IP, IP_PKTINFO, &get, &get_len),
+ SyscallSucceedsWithValue(0));
+ EXPECT_EQ(get_len, sizeof(get));
+ EXPECT_EQ(get, kSockOptOff);
+}
+
+// Test setsockopt and getsockopt for a socket with IP_PKTINFO option.
+TEST_P(UDPSocketPairTest, SetAndGetIPPKTINFO) {
+ auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair());
+
+ int level = SOL_IP;
+ int type = IP_PKTINFO;
+
+ // Check getsockopt before IP_PKTINFO is set.
+ int get = -1;
+ socklen_t get_len = sizeof(get);
+
+ ASSERT_THAT(setsockopt(sockets->first_fd(), level, type, &kSockOptOn,
+ sizeof(kSockOptOn)),
+ SyscallSucceedsWithValue(0));
+
+ ASSERT_THAT(getsockopt(sockets->first_fd(), level, type, &get, &get_len),
+ SyscallSucceedsWithValue(0));
+ EXPECT_EQ(get, kSockOptOn);
+ EXPECT_EQ(get_len, sizeof(get));
+
+ ASSERT_THAT(setsockopt(sockets->first_fd(), level, type, &kSockOptOff,
+ sizeof(kSockOptOff)),
+ SyscallSucceedsWithValue(0));
+
+ ASSERT_THAT(getsockopt(sockets->first_fd(), level, type, &get, &get_len),
+ SyscallSucceedsWithValue(0));
+ EXPECT_EQ(get, kSockOptOff);
+ EXPECT_EQ(get_len, sizeof(get));
+}
+
} // namespace testing
} // namespace gvisor
diff --git a/test/syscalls/linux/socket_ipv4_udp_unbound.cc b/test/syscalls/linux/socket_ipv4_udp_unbound.cc
index 990ccf23c..bc4b07a62 100644
--- a/test/syscalls/linux/socket_ipv4_udp_unbound.cc
+++ b/test/syscalls/linux/socket_ipv4_udp_unbound.cc
@@ -15,6 +15,7 @@
#include "test/syscalls/linux/socket_ipv4_udp_unbound.h"
#include <arpa/inet.h>
+#include <net/if.h>
#include <sys/ioctl.h>
#include <sys/socket.h>
#include <sys/un.h>
@@ -2128,5 +2129,88 @@ TEST_P(IPv4UDPUnboundSocketTest, ReuseAddrReusePortDistribution) {
SyscallSucceedsWithValue(kMessageSize));
}
+// Test that socket will receive packet info control message.
+TEST_P(IPv4UDPUnboundSocketTest, SetAndReceiveIPPKTINFO) {
+ // TODO(gvisor.dev/issue/1202): ioctl() is not supported by hostinet.
+ SKIP_IF((IsRunningWithHostinet()));
+
+ auto sender = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+ auto receiver = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
+ auto sender_addr = V4Loopback();
+ int level = SOL_IP;
+ int type = IP_PKTINFO;
+
+ ASSERT_THAT(
+ bind(receiver->get(), reinterpret_cast<sockaddr*>(&sender_addr.addr),
+ sender_addr.addr_len),
+ SyscallSucceeds());
+ socklen_t sender_addr_len = sender_addr.addr_len;
+ ASSERT_THAT(getsockname(receiver->get(),
+ reinterpret_cast<sockaddr*>(&sender_addr.addr),
+ &sender_addr_len),
+ SyscallSucceeds());
+ EXPECT_EQ(sender_addr_len, sender_addr.addr_len);
+
+ auto receiver_addr = V4Loopback();
+ reinterpret_cast<sockaddr_in*>(&receiver_addr.addr)->sin_port =
+ reinterpret_cast<sockaddr_in*>(&sender_addr.addr)->sin_port;
+ ASSERT_THAT(
+ connect(sender->get(), reinterpret_cast<sockaddr*>(&receiver_addr.addr),
+ receiver_addr.addr_len),
+ SyscallSucceeds());
+
+ // Allow socket to receive control message.
+ ASSERT_THAT(
+ setsockopt(receiver->get(), level, type, &kSockOptOn, sizeof(kSockOptOn)),
+ SyscallSucceeds());
+
+ // Prepare message to send.
+ constexpr size_t kDataLength = 1024;
+ msghdr sent_msg = {};
+ iovec sent_iov = {};
+ char sent_data[kDataLength];
+ sent_iov.iov_base = sent_data;
+ sent_iov.iov_len = kDataLength;
+ sent_msg.msg_iov = &sent_iov;
+ sent_msg.msg_iovlen = 1;
+ sent_msg.msg_flags = 0;
+
+ ASSERT_THAT(RetryEINTR(sendmsg)(sender->get(), &sent_msg, 0),
+ SyscallSucceedsWithValue(kDataLength));
+
+ msghdr received_msg = {};
+ iovec received_iov = {};
+ char received_data[kDataLength];
+ char received_cmsg_buf[CMSG_SPACE(sizeof(in_pktinfo))] = {};
+ size_t cmsg_data_len = sizeof(in_pktinfo);
+ received_iov.iov_base = received_data;
+ received_iov.iov_len = kDataLength;
+ received_msg.msg_iov = &received_iov;
+ received_msg.msg_iovlen = 1;
+ received_msg.msg_controllen = CMSG_LEN(cmsg_data_len);
+ received_msg.msg_control = received_cmsg_buf;
+
+ ASSERT_THAT(RetryEINTR(recvmsg)(receiver->get(), &received_msg, 0),
+ SyscallSucceedsWithValue(kDataLength));
+
+ cmsghdr* cmsg = CMSG_FIRSTHDR(&received_msg);
+ ASSERT_NE(cmsg, nullptr);
+ EXPECT_EQ(cmsg->cmsg_len, CMSG_LEN(cmsg_data_len));
+ EXPECT_EQ(cmsg->cmsg_level, level);
+ EXPECT_EQ(cmsg->cmsg_type, type);
+
+ // Get loopback index.
+ ifreq ifr = {};
+ absl::SNPrintF(ifr.ifr_name, IFNAMSIZ, "lo");
+ ASSERT_THAT(ioctl(sender->get(), SIOCGIFINDEX, &ifr), SyscallSucceeds());
+ ASSERT_NE(ifr.ifr_ifindex, 0);
+
+ // Check the data
+ in_pktinfo received_pktinfo = {};
+ memcpy(&received_pktinfo, CMSG_DATA(cmsg), sizeof(in_pktinfo));
+ EXPECT_EQ(received_pktinfo.ipi_ifindex, ifr.ifr_ifindex);
+ EXPECT_EQ(received_pktinfo.ipi_spec_dst.s_addr, htonl(INADDR_LOOPBACK));
+ EXPECT_EQ(received_pktinfo.ipi_addr.s_addr, htonl(INADDR_LOOPBACK));
+}
} // namespace testing
} // namespace gvisor
diff --git a/test/syscalls/linux/udp_socket_test_cases.cc b/test/syscalls/linux/udp_socket_test_cases.cc
index a2f6ef8cc..9f8de6b48 100644
--- a/test/syscalls/linux/udp_socket_test_cases.cc
+++ b/test/syscalls/linux/udp_socket_test_cases.cc
@@ -1495,6 +1495,5 @@ TEST_P(UdpSocketTest, SendAndReceiveTOS) {
memcpy(&received_tos, CMSG_DATA(cmsg), sizeof(received_tos));
EXPECT_EQ(received_tos, sent_tos);
}
-
} // namespace testing
} // namespace gvisor
diff --git a/tools/bazeldefs/defs.bzl b/tools/bazeldefs/defs.bzl
index 08c29ff1c..6798362dc 100644
--- a/tools/bazeldefs/defs.bzl
+++ b/tools/bazeldefs/defs.bzl
@@ -72,7 +72,7 @@ def go_test(name, **kwargs):
**kwargs
)
-def py_requirement(name, direct = False):
+def py_requirement(name, direct = True):
return _py_requirement(name)
def select_arch(amd64 = "amd64", arm64 = "arm64", default = None, **kwargs):
diff --git a/tools/defs.bzl b/tools/defs.bzl
index d4690cc1a..46249f9c4 100644
--- a/tools/defs.bzl
+++ b/tools/defs.bzl
@@ -110,6 +110,8 @@ def go_library(name, srcs, deps = [], imports = [], stateify = True, marshal = F
"""
all_srcs = srcs
all_deps = deps
+ dirname, _, _ = native.package_name().rpartition("/")
+ full_pkg = dirname + "/" + name
if stateify:
# Only do stateification for non-state packages without manual autogen.
# First, we need to segregate the input files via the special suffixes,
@@ -120,7 +122,7 @@ def go_library(name, srcs, deps = [], imports = [], stateify = True, marshal = F
name = name + suffix + "_state_autogen_with_imports",
srcs = srcs,
imports = imports,
- package = name,
+ package = full_pkg,
out = name + suffix + "_state_autogen_with_imports.go",
)
go_imports(
diff --git a/tools/go_stateify/defs.bzl b/tools/go_stateify/defs.bzl
index bdb966362..6a5e666f0 100644
--- a/tools/go_stateify/defs.bzl
+++ b/tools/go_stateify/defs.bzl
@@ -6,7 +6,7 @@ def _go_stateify_impl(ctx):
# Run the stateify command.
args = ["-output=%s" % output.path]
- args.append("-pkg=%s" % ctx.attr.package)
+ args.append("-fullpkg=%s" % ctx.attr.package)
if ctx.attr._statepkg:
args.append("-statepkg=%s" % ctx.attr._statepkg)
if ctx.attr.imports:
@@ -43,7 +43,7 @@ for statified types.
mandatory = False,
),
"package": attr.string(
- doc = "The package name for the input sources.",
+ doc = "The fully qualified package name for the input sources.",
mandatory = True,
),
"out": attr.output(
diff --git a/tools/go_stateify/main.go b/tools/go_stateify/main.go
index aa9d4543e..3437aa476 100644
--- a/tools/go_stateify/main.go
+++ b/tools/go_stateify/main.go
@@ -23,6 +23,7 @@ import (
"go/parser"
"go/token"
"os"
+ "path/filepath"
"reflect"
"strings"
"sync"
@@ -31,7 +32,7 @@ import (
)
var (
- pkg = flag.String("pkg", "", "output package")
+ fullPkg = flag.String("fullpkg", "", "fully qualified output package")
imports = flag.String("imports", "", "extra imports for the output file")
output = flag.String("output", "", "output file")
statePkg = flag.String("statepkg", "", "state import package; defaults to empty")
@@ -170,7 +171,7 @@ func main() {
flag.Usage()
os.Exit(1)
}
- if *pkg == "" {
+ if *fullPkg == "" {
fmt.Fprintf(os.Stderr, "Error: package required.")
os.Exit(1)
}
@@ -202,7 +203,7 @@ func main() {
// Declare our emission closures.
emitRegister := func(name string) {
- initCalls = append(initCalls, fmt.Sprintf("%sRegister(\"%s.%s\", (*%s)(nil), state.Fns{Save: (*%s).save, Load: (*%s).load})", statePrefix, *pkg, name, name, name, name))
+ initCalls = append(initCalls, fmt.Sprintf("%sRegister(\"%s.%s\", (*%s)(nil), state.Fns{Save: (*%s).save, Load: (*%s).load})", statePrefix, *fullPkg, name, name, name, name))
}
emitZeroCheck := func(name string) {
fmt.Fprintf(outputFile, " if !%sIsZeroValue(x.%s) { m.Failf(\"%s is %%v, expected zero\", x.%s) }\n", statePrefix, name, name, name)
@@ -233,7 +234,8 @@ func main() {
}
// Emit the package name.
- fmt.Fprintf(outputFile, "package %s\n\n", *pkg)
+ _, pkg := filepath.Split(*fullPkg)
+ fmt.Fprintf(outputFile, "package %s\n\n", pkg)
// Emit the imports lazily.
var once sync.Once