summaryrefslogtreecommitdiffhomepage
path: root/pkg/sentry/socket/netlink
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/sentry/socket/netlink')
-rw-r--r--pkg/sentry/socket/netlink/message_test.go45
-rw-r--r--pkg/sentry/socket/netlink/socket.go2
2 files changed, 10 insertions, 37 deletions
diff --git a/pkg/sentry/socket/netlink/message_test.go b/pkg/sentry/socket/netlink/message_test.go
index 968968469..1604b2792 100644
--- a/pkg/sentry/socket/netlink/message_test.go
+++ b/pkg/sentry/socket/netlink/message_test.go
@@ -25,33 +25,14 @@ import (
"gvisor.dev/gvisor/pkg/sentry/socket/netlink"
)
-type dummyNetlinkMsg struct {
- marshal.StubMarshallable
- Foo uint16
-}
-
-func (*dummyNetlinkMsg) SizeBytes() int {
- return 2
-}
-
-func (m *dummyNetlinkMsg) MarshalUnsafe(dst []byte) {
- p := primitive.Uint16(m.Foo)
- p.MarshalUnsafe(dst)
-}
-
-func (m *dummyNetlinkMsg) UnmarshalUnsafe(src []byte) {
- var p primitive.Uint16
- p.UnmarshalUnsafe(src)
- m.Foo = uint16(p)
-}
-
func TestParseMessage(t *testing.T) {
+ dummyNetlinkMsg := primitive.Uint16(0x3130)
tests := []struct {
desc string
input []byte
header linux.NetlinkMessageHeader
- dataMsg *dummyNetlinkMsg
+ dataMsg marshal.Marshallable
restLen int
ok bool
}{
@@ -72,9 +53,7 @@ func TestParseMessage(t *testing.T) {
Seq: 3,
PortID: 4,
},
- dataMsg: &dummyNetlinkMsg{
- Foo: 0x3130,
- },
+ dataMsg: &dummyNetlinkMsg,
restLen: 0,
ok: true,
},
@@ -96,9 +75,7 @@ func TestParseMessage(t *testing.T) {
Seq: 3,
PortID: 4,
},
- dataMsg: &dummyNetlinkMsg{
- Foo: 0x3130,
- },
+ dataMsg: &dummyNetlinkMsg,
restLen: 1,
ok: true,
},
@@ -119,9 +96,7 @@ func TestParseMessage(t *testing.T) {
Seq: 3,
PortID: 4,
},
- dataMsg: &dummyNetlinkMsg{
- Foo: 0x3130,
- },
+ dataMsg: &dummyNetlinkMsg,
restLen: 0,
ok: true,
},
@@ -143,9 +118,7 @@ func TestParseMessage(t *testing.T) {
Seq: 3,
PortID: 4,
},
- dataMsg: &dummyNetlinkMsg{
- Foo: 0x3130,
- },
+ dataMsg: &dummyNetlinkMsg,
restLen: 0,
ok: true,
},
@@ -199,11 +172,11 @@ func TestParseMessage(t *testing.T) {
t.Errorf("%v: got hdr = %+v, want = %+v", test.desc, msg.Header(), test.header)
}
- dataMsg := &dummyNetlinkMsg{}
- _, dataOk := msg.GetData(dataMsg)
+ var dataMsg primitive.Uint16
+ _, dataOk := msg.GetData(&dataMsg)
if !dataOk {
t.Errorf("%v: GetData.ok = %v, want = true", test.desc, dataOk)
- } else if !reflect.DeepEqual(dataMsg, test.dataMsg) {
+ } else if !reflect.DeepEqual(&dataMsg, test.dataMsg) {
t.Errorf("%v: GetData.msg = %+v, want = %+v", test.desc, dataMsg, test.dataMsg)
}
diff --git a/pkg/sentry/socket/netlink/socket.go b/pkg/sentry/socket/netlink/socket.go
index 267155807..19c8f340d 100644
--- a/pkg/sentry/socket/netlink/socket.go
+++ b/pkg/sentry/socket/netlink/socket.go
@@ -223,7 +223,7 @@ func ExtractSockAddr(b []byte) (*linux.SockAddrNetlink, *syserr.Error) {
}
var sa linux.SockAddrNetlink
- sa.UnmarshalUnsafe(b[:sa.SizeBytes()])
+ sa.UnmarshalUnsafe(b)
if sa.Family != linux.AF_NETLINK {
return nil, syserr.ErrInvalidArgument