diff options
Diffstat (limited to 'pkg/sentry/socket/netlink')
-rw-r--r-- | pkg/sentry/socket/netlink/message_test.go | 45 | ||||
-rw-r--r-- | pkg/sentry/socket/netlink/socket.go | 2 |
2 files changed, 10 insertions, 37 deletions
diff --git a/pkg/sentry/socket/netlink/message_test.go b/pkg/sentry/socket/netlink/message_test.go index 968968469..1604b2792 100644 --- a/pkg/sentry/socket/netlink/message_test.go +++ b/pkg/sentry/socket/netlink/message_test.go @@ -25,33 +25,14 @@ import ( "gvisor.dev/gvisor/pkg/sentry/socket/netlink" ) -type dummyNetlinkMsg struct { - marshal.StubMarshallable - Foo uint16 -} - -func (*dummyNetlinkMsg) SizeBytes() int { - return 2 -} - -func (m *dummyNetlinkMsg) MarshalUnsafe(dst []byte) { - p := primitive.Uint16(m.Foo) - p.MarshalUnsafe(dst) -} - -func (m *dummyNetlinkMsg) UnmarshalUnsafe(src []byte) { - var p primitive.Uint16 - p.UnmarshalUnsafe(src) - m.Foo = uint16(p) -} - func TestParseMessage(t *testing.T) { + dummyNetlinkMsg := primitive.Uint16(0x3130) tests := []struct { desc string input []byte header linux.NetlinkMessageHeader - dataMsg *dummyNetlinkMsg + dataMsg marshal.Marshallable restLen int ok bool }{ @@ -72,9 +53,7 @@ func TestParseMessage(t *testing.T) { Seq: 3, PortID: 4, }, - dataMsg: &dummyNetlinkMsg{ - Foo: 0x3130, - }, + dataMsg: &dummyNetlinkMsg, restLen: 0, ok: true, }, @@ -96,9 +75,7 @@ func TestParseMessage(t *testing.T) { Seq: 3, PortID: 4, }, - dataMsg: &dummyNetlinkMsg{ - Foo: 0x3130, - }, + dataMsg: &dummyNetlinkMsg, restLen: 1, ok: true, }, @@ -119,9 +96,7 @@ func TestParseMessage(t *testing.T) { Seq: 3, PortID: 4, }, - dataMsg: &dummyNetlinkMsg{ - Foo: 0x3130, - }, + dataMsg: &dummyNetlinkMsg, restLen: 0, ok: true, }, @@ -143,9 +118,7 @@ func TestParseMessage(t *testing.T) { Seq: 3, PortID: 4, }, - dataMsg: &dummyNetlinkMsg{ - Foo: 0x3130, - }, + dataMsg: &dummyNetlinkMsg, restLen: 0, ok: true, }, @@ -199,11 +172,11 @@ func TestParseMessage(t *testing.T) { t.Errorf("%v: got hdr = %+v, want = %+v", test.desc, msg.Header(), test.header) } - dataMsg := &dummyNetlinkMsg{} - _, dataOk := msg.GetData(dataMsg) + var dataMsg primitive.Uint16 + _, dataOk := msg.GetData(&dataMsg) if !dataOk { t.Errorf("%v: GetData.ok = %v, want = true", test.desc, dataOk) - } else if !reflect.DeepEqual(dataMsg, test.dataMsg) { + } else if !reflect.DeepEqual(&dataMsg, test.dataMsg) { t.Errorf("%v: GetData.msg = %+v, want = %+v", test.desc, dataMsg, test.dataMsg) } diff --git a/pkg/sentry/socket/netlink/socket.go b/pkg/sentry/socket/netlink/socket.go index 267155807..19c8f340d 100644 --- a/pkg/sentry/socket/netlink/socket.go +++ b/pkg/sentry/socket/netlink/socket.go @@ -223,7 +223,7 @@ func ExtractSockAddr(b []byte) (*linux.SockAddrNetlink, *syserr.Error) { } var sa linux.SockAddrNetlink - sa.UnmarshalUnsafe(b[:sa.SizeBytes()]) + sa.UnmarshalUnsafe(b) if sa.Family != linux.AF_NETLINK { return nil, syserr.ErrInvalidArgument |