summaryrefslogtreecommitdiffhomepage
path: root/pkg/sentry/socket/netlink
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/sentry/socket/netlink')
-rw-r--r--pkg/sentry/socket/netlink/BUILD4
-rw-r--r--pkg/sentry/socket/netlink/message.go40
-rw-r--r--pkg/sentry/socket/netlink/message_test.go18
-rw-r--r--pkg/sentry/socket/netlink/route/BUILD1
-rw-r--r--pkg/sentry/socket/netlink/route/protocol.go30
-rw-r--r--pkg/sentry/socket/netlink/socket.go21
6 files changed, 68 insertions, 46 deletions
diff --git a/pkg/sentry/socket/netlink/BUILD b/pkg/sentry/socket/netlink/BUILD
index 171b95c63..64cd263da 100644
--- a/pkg/sentry/socket/netlink/BUILD
+++ b/pkg/sentry/socket/netlink/BUILD
@@ -14,7 +14,7 @@ go_library(
visibility = ["//pkg/sentry:internal"],
deps = [
"//pkg/abi/linux",
- "//pkg/binary",
+ "//pkg/bits",
"//pkg/context",
"//pkg/hostarch",
"//pkg/marshal",
@@ -50,5 +50,7 @@ go_test(
deps = [
":netlink",
"//pkg/abi/linux",
+ "//pkg/marshal",
+ "//pkg/marshal/primitive",
],
)
diff --git a/pkg/sentry/socket/netlink/message.go b/pkg/sentry/socket/netlink/message.go
index ab0e68af7..80385bfdc 100644
--- a/pkg/sentry/socket/netlink/message.go
+++ b/pkg/sentry/socket/netlink/message.go
@@ -19,15 +19,17 @@ import (
"math"
"gvisor.dev/gvisor/pkg/abi/linux"
- "gvisor.dev/gvisor/pkg/binary"
+ "gvisor.dev/gvisor/pkg/bits"
"gvisor.dev/gvisor/pkg/hostarch"
+ "gvisor.dev/gvisor/pkg/marshal"
+ "gvisor.dev/gvisor/pkg/marshal/primitive"
)
// alignPad returns the length of padding required for alignment.
//
// Preconditions: align is a power of two.
func alignPad(length int, align uint) int {
- return binary.AlignUp(length, align) - length
+ return bits.AlignUp(length, align) - length
}
// Message contains a complete serialized netlink message.
@@ -42,7 +44,7 @@ type Message struct {
func NewMessage(hdr linux.NetlinkMessageHeader) *Message {
return &Message{
hdr: hdr,
- buf: binary.Marshal(nil, hostarch.ByteOrder, hdr),
+ buf: marshal.Marshal(&hdr),
}
}
@@ -58,7 +60,7 @@ func ParseMessage(buf []byte) (msg *Message, rest []byte, ok bool) {
return
}
var hdr linux.NetlinkMessageHeader
- binary.Unmarshal(hdrBytes, hostarch.ByteOrder, &hdr)
+ hdr.UnmarshalUnsafe(hdrBytes)
// Msg portion.
totalMsgLen := int(hdr.Length)
@@ -92,7 +94,7 @@ func (m *Message) Header() linux.NetlinkMessageHeader {
// GetData unmarshals the payload message header from this netlink message, and
// returns the attributes portion.
-func (m *Message) GetData(msg interface{}) (AttrsView, bool) {
+func (m *Message) GetData(msg marshal.Marshallable) (AttrsView, bool) {
b := BytesView(m.buf)
_, ok := b.Extract(linux.NetlinkMessageHeaderSize)
@@ -100,12 +102,12 @@ func (m *Message) GetData(msg interface{}) (AttrsView, bool) {
return nil, false
}
- size := int(binary.Size(msg))
+ size := msg.SizeBytes()
msgBytes, ok := b.Extract(size)
if !ok {
return nil, false
}
- binary.Unmarshal(msgBytes, hostarch.ByteOrder, msg)
+ msg.UnmarshalUnsafe(msgBytes)
numPad := alignPad(linux.NetlinkMessageHeaderSize+size, linux.NLMSG_ALIGNTO)
// Linux permits the last message not being aligned, just consume all of it.
@@ -131,7 +133,7 @@ func (m *Message) Finalize() []byte {
// Align the message. Note that the message length in the header (set
// above) is the useful length of the message, not the total aligned
// length. See net/netlink/af_netlink.c:__nlmsg_put.
- aligned := binary.AlignUp(len(m.buf), linux.NLMSG_ALIGNTO)
+ aligned := bits.AlignUp(len(m.buf), linux.NLMSG_ALIGNTO)
m.putZeros(aligned - len(m.buf))
return m.buf
}
@@ -145,45 +147,45 @@ func (m *Message) putZeros(n int) {
}
// Put serializes v into the message.
-func (m *Message) Put(v interface{}) {
- m.buf = binary.Marshal(m.buf, hostarch.ByteOrder, v)
+func (m *Message) Put(v marshal.Marshallable) {
+ m.buf = append(m.buf, marshal.Marshal(v)...)
}
// PutAttr adds v to the message as a netlink attribute.
//
// Preconditions: The serialized attribute (linux.NetlinkAttrHeaderSize +
-// binary.Size(v) fits in math.MaxUint16 bytes.
-func (m *Message) PutAttr(atype uint16, v interface{}) {
- l := linux.NetlinkAttrHeaderSize + int(binary.Size(v))
+// v.SizeBytes()) fits in math.MaxUint16 bytes.
+func (m *Message) PutAttr(atype uint16, v marshal.Marshallable) {
+ l := linux.NetlinkAttrHeaderSize + v.SizeBytes()
if l > math.MaxUint16 {
panic(fmt.Sprintf("attribute too large: %d", l))
}
- m.Put(linux.NetlinkAttrHeader{
+ m.Put(&linux.NetlinkAttrHeader{
Type: atype,
Length: uint16(l),
})
m.Put(v)
// Align the attribute.
- aligned := binary.AlignUp(l, linux.NLA_ALIGNTO)
+ aligned := bits.AlignUp(l, linux.NLA_ALIGNTO)
m.putZeros(aligned - l)
}
// PutAttrString adds s to the message as a netlink attribute.
func (m *Message) PutAttrString(atype uint16, s string) {
l := linux.NetlinkAttrHeaderSize + len(s) + 1
- m.Put(linux.NetlinkAttrHeader{
+ m.Put(&linux.NetlinkAttrHeader{
Type: atype,
Length: uint16(l),
})
// String + NUL-termination.
- m.Put([]byte(s))
+ m.Put(primitive.AsByteSlice([]byte(s)))
m.putZeros(1)
// Align the attribute.
- aligned := binary.AlignUp(l, linux.NLA_ALIGNTO)
+ aligned := bits.AlignUp(l, linux.NLA_ALIGNTO)
m.putZeros(aligned - l)
}
@@ -251,7 +253,7 @@ func (v AttrsView) ParseFirst() (hdr linux.NetlinkAttrHeader, value []byte, rest
if !ok {
return
}
- binary.Unmarshal(hdrBytes, hostarch.ByteOrder, &hdr)
+ hdr.UnmarshalUnsafe(hdrBytes)
value, ok = b.Extract(int(hdr.Length) - linux.NetlinkAttrHeaderSize)
if !ok {
diff --git a/pkg/sentry/socket/netlink/message_test.go b/pkg/sentry/socket/netlink/message_test.go
index ef13d9386..968968469 100644
--- a/pkg/sentry/socket/netlink/message_test.go
+++ b/pkg/sentry/socket/netlink/message_test.go
@@ -20,13 +20,31 @@ import (
"testing"
"gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/marshal"
+ "gvisor.dev/gvisor/pkg/marshal/primitive"
"gvisor.dev/gvisor/pkg/sentry/socket/netlink"
)
type dummyNetlinkMsg struct {
+ marshal.StubMarshallable
Foo uint16
}
+func (*dummyNetlinkMsg) SizeBytes() int {
+ return 2
+}
+
+func (m *dummyNetlinkMsg) MarshalUnsafe(dst []byte) {
+ p := primitive.Uint16(m.Foo)
+ p.MarshalUnsafe(dst)
+}
+
+func (m *dummyNetlinkMsg) UnmarshalUnsafe(src []byte) {
+ var p primitive.Uint16
+ p.UnmarshalUnsafe(src)
+ m.Foo = uint16(p)
+}
+
func TestParseMessage(t *testing.T) {
tests := []struct {
desc string
diff --git a/pkg/sentry/socket/netlink/route/BUILD b/pkg/sentry/socket/netlink/route/BUILD
index 744fc74f4..c6c04b4e3 100644
--- a/pkg/sentry/socket/netlink/route/BUILD
+++ b/pkg/sentry/socket/netlink/route/BUILD
@@ -11,6 +11,7 @@ go_library(
deps = [
"//pkg/abi/linux",
"//pkg/context",
+ "//pkg/marshal/primitive",
"//pkg/sentry/inet",
"//pkg/sentry/kernel",
"//pkg/sentry/kernel/auth",
diff --git a/pkg/sentry/socket/netlink/route/protocol.go b/pkg/sentry/socket/netlink/route/protocol.go
index 5a2255db3..86f6419dc 100644
--- a/pkg/sentry/socket/netlink/route/protocol.go
+++ b/pkg/sentry/socket/netlink/route/protocol.go
@@ -21,6 +21,7 @@ import (
"golang.org/x/sys/unix"
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/marshal/primitive"
"gvisor.dev/gvisor/pkg/sentry/inet"
"gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
@@ -167,7 +168,7 @@ func addNewLinkMessage(ms *netlink.MessageSet, idx int32, i inet.Interface) {
Type: linux.RTM_NEWLINK,
})
- m.Put(linux.InterfaceInfoMessage{
+ m.Put(&linux.InterfaceInfoMessage{
Family: linux.AF_UNSPEC,
Type: i.DeviceType,
Index: idx,
@@ -175,7 +176,7 @@ func addNewLinkMessage(ms *netlink.MessageSet, idx int32, i inet.Interface) {
})
m.PutAttrString(linux.IFLA_IFNAME, i.Name)
- m.PutAttr(linux.IFLA_MTU, i.MTU)
+ m.PutAttr(linux.IFLA_MTU, primitive.AllocateUint32(i.MTU))
mac := make([]byte, 6)
brd := mac
@@ -183,8 +184,8 @@ func addNewLinkMessage(ms *netlink.MessageSet, idx int32, i inet.Interface) {
mac = i.Addr
brd = bytes.Repeat([]byte{0xff}, len(i.Addr))
}
- m.PutAttr(linux.IFLA_ADDRESS, mac)
- m.PutAttr(linux.IFLA_BROADCAST, brd)
+ m.PutAttr(linux.IFLA_ADDRESS, primitive.AsByteSlice(mac))
+ m.PutAttr(linux.IFLA_BROADCAST, primitive.AsByteSlice(brd))
// TODO(gvisor.dev/issue/578): There are many more attributes.
}
@@ -216,14 +217,15 @@ func (p *Protocol) dumpAddrs(ctx context.Context, msg *netlink.Message, ms *netl
Type: linux.RTM_NEWADDR,
})
- m.Put(linux.InterfaceAddrMessage{
+ m.Put(&linux.InterfaceAddrMessage{
Family: a.Family,
PrefixLen: a.PrefixLen,
Index: uint32(id),
})
- m.PutAttr(linux.IFA_LOCAL, []byte(a.Addr))
- m.PutAttr(linux.IFA_ADDRESS, []byte(a.Addr))
+ addr := primitive.ByteSlice([]byte(a.Addr))
+ m.PutAttr(linux.IFA_LOCAL, &addr)
+ m.PutAttr(linux.IFA_ADDRESS, &addr)
// TODO(gvisor.dev/issue/578): There are many more attributes.
}
@@ -366,7 +368,7 @@ func (p *Protocol) dumpRoutes(ctx context.Context, msg *netlink.Message, ms *net
Type: linux.RTM_NEWROUTE,
})
- m.Put(linux.RouteMessage{
+ m.Put(&linux.RouteMessage{
Family: rt.Family,
DstLen: rt.DstLen,
SrcLen: rt.SrcLen,
@@ -382,18 +384,18 @@ func (p *Protocol) dumpRoutes(ctx context.Context, msg *netlink.Message, ms *net
Flags: rt.Flags,
})
- m.PutAttr(254, []byte{123})
+ m.PutAttr(254, primitive.AsByteSlice([]byte{123}))
if rt.DstLen > 0 {
- m.PutAttr(linux.RTA_DST, rt.DstAddr)
+ m.PutAttr(linux.RTA_DST, primitive.AsByteSlice(rt.DstAddr))
}
if rt.SrcLen > 0 {
- m.PutAttr(linux.RTA_SRC, rt.SrcAddr)
+ m.PutAttr(linux.RTA_SRC, primitive.AsByteSlice(rt.SrcAddr))
}
if rt.OutputInterface != 0 {
- m.PutAttr(linux.RTA_OIF, rt.OutputInterface)
+ m.PutAttr(linux.RTA_OIF, primitive.AllocateInt32(rt.OutputInterface))
}
if len(rt.GatewayAddr) > 0 {
- m.PutAttr(linux.RTA_GATEWAY, rt.GatewayAddr)
+ m.PutAttr(linux.RTA_GATEWAY, primitive.AsByteSlice(rt.GatewayAddr))
}
// TODO(gvisor.dev/issue/578): There are many more attributes.
@@ -503,7 +505,7 @@ func (p *Protocol) ProcessMessage(ctx context.Context, msg *netlink.Message, ms
hdr := msg.Header()
// All messages start with a 1 byte protocol family.
- var family uint8
+ var family primitive.Uint8
if _, ok := msg.GetData(&family); !ok {
// Linux ignores messages missing the protocol family. See
// net/core/rtnetlink.c:rtnetlink_rcv_msg.
diff --git a/pkg/sentry/socket/netlink/socket.go b/pkg/sentry/socket/netlink/socket.go
index 30c297149..d75a2879f 100644
--- a/pkg/sentry/socket/netlink/socket.go
+++ b/pkg/sentry/socket/netlink/socket.go
@@ -20,7 +20,6 @@ import (
"math"
"gvisor.dev/gvisor/pkg/abi/linux"
- "gvisor.dev/gvisor/pkg/binary"
"gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/hostarch"
"gvisor.dev/gvisor/pkg/marshal"
@@ -223,7 +222,7 @@ func ExtractSockAddr(b []byte) (*linux.SockAddrNetlink, *syserr.Error) {
}
var sa linux.SockAddrNetlink
- binary.Unmarshal(b[:linux.SockAddrNetlinkSize], hostarch.ByteOrder, &sa)
+ sa.UnmarshalUnsafe(b[:sa.SizeBytes()])
if sa.Family != linux.AF_NETLINK {
return nil, syserr.ErrInvalidArgument
@@ -338,16 +337,14 @@ func (s *socketOpsCommon) GetSockOpt(t *kernel.Task, level int, name int, outPtr
}
s.mu.Lock()
defer s.mu.Unlock()
- sendBufferSizeP := primitive.Int32(s.sendBufferSize)
- return &sendBufferSizeP, nil
+ return primitive.AllocateInt32(int32(s.sendBufferSize)), nil
case linux.SO_RCVBUF:
if outLen < sizeOfInt32 {
return nil, syserr.ErrInvalidArgument
}
// We don't have limit on receiving size.
- recvBufferSizeP := primitive.Int32(math.MaxInt32)
- return &recvBufferSizeP, nil
+ return primitive.AllocateInt32(math.MaxInt32), nil
case linux.SO_PASSCRED:
if outLen < sizeOfInt32 {
@@ -484,7 +481,7 @@ func (s *socketOpsCommon) GetSockName(t *kernel.Task) (linux.SockAddr, uint32, *
Family: linux.AF_NETLINK,
PortID: uint32(s.portID),
}
- return sa, uint32(binary.Size(sa)), nil
+ return sa, uint32(sa.SizeBytes()), nil
}
// GetPeerName implements socket.Socket.GetPeerName.
@@ -495,7 +492,7 @@ func (s *socketOpsCommon) GetPeerName(t *kernel.Task) (linux.SockAddr, uint32, *
// must be the kernel.
PortID: 0,
}
- return sa, uint32(binary.Size(sa)), nil
+ return sa, uint32(sa.SizeBytes()), nil
}
// RecvMsg implements socket.Socket.RecvMsg.
@@ -504,7 +501,7 @@ func (s *socketOpsCommon) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags
Family: linux.AF_NETLINK,
PortID: 0,
}
- fromLen := uint32(binary.Size(from))
+ fromLen := uint32(from.SizeBytes())
trunc := flags&linux.MSG_TRUNC != 0
@@ -640,7 +637,7 @@ func (s *socketOpsCommon) sendResponse(ctx context.Context, ms *MessageSet) *sys
})
// Add the dump_done_errno payload.
- m.Put(int64(0))
+ m.Put(primitive.AllocateInt64(0))
_, notify, err := s.connection.Send(ctx, [][]byte{m.Finalize()}, cms, tcpip.FullAddress{})
if err != nil && err != syserr.ErrWouldBlock {
@@ -658,7 +655,7 @@ func dumpErrorMesage(hdr linux.NetlinkMessageHeader, ms *MessageSet, err *syserr
m := ms.AddMessage(linux.NetlinkMessageHeader{
Type: linux.NLMSG_ERROR,
})
- m.Put(linux.NetlinkErrorMessage{
+ m.Put(&linux.NetlinkErrorMessage{
Error: int32(-err.ToLinux().Number()),
Header: hdr,
})
@@ -668,7 +665,7 @@ func dumpAckMesage(hdr linux.NetlinkMessageHeader, ms *MessageSet) {
m := ms.AddMessage(linux.NetlinkMessageHeader{
Type: linux.NLMSG_ERROR,
})
- m.Put(linux.NetlinkErrorMessage{
+ m.Put(&linux.NetlinkErrorMessage{
Error: 0,
Header: hdr,
})