summaryrefslogtreecommitdiffhomepage
path: root/pkg/sentry/socket/control/control.go
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/sentry/socket/control/control.go')
-rw-r--r--pkg/sentry/socket/control/control.go53
1 files changed, 25 insertions, 28 deletions
diff --git a/pkg/sentry/socket/control/control.go b/pkg/sentry/socket/control/control.go
index c284efde5..5b81e8379 100644
--- a/pkg/sentry/socket/control/control.go
+++ b/pkg/sentry/socket/control/control.go
@@ -26,7 +26,6 @@ import (
"gvisor.dev/gvisor/pkg/sentry/socket"
"gvisor.dev/gvisor/pkg/sentry/socket/unix/transport"
"gvisor.dev/gvisor/pkg/syserror"
- "gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/usermem"
)
@@ -344,32 +343,32 @@ func PackTClass(t *kernel.Task, tClass uint32, buf []byte) []byte {
}
// PackIPPacketInfo packs an IP_PKTINFO socket control message.
-func PackIPPacketInfo(t *kernel.Task, packetInfo tcpip.IPPacketInfo, buf []byte) []byte {
- var p linux.ControlMessageIPPacketInfo
- p.NIC = int32(packetInfo.NIC)
- copy(p.LocalAddr[:], []byte(packetInfo.LocalAddr))
- copy(p.DestinationAddr[:], []byte(packetInfo.DestinationAddr))
-
+func PackIPPacketInfo(t *kernel.Task, packetInfo *linux.ControlMessageIPPacketInfo, buf []byte) []byte {
return putCmsgStruct(
buf,
linux.SOL_IP,
linux.IP_PKTINFO,
t.Arch().Width(),
- p,
+ packetInfo,
)
}
// PackOriginalDstAddress packs an IP_RECVORIGINALDSTADDR socket control message.
-func PackOriginalDstAddress(t *kernel.Task, family int, originalDstAddress tcpip.FullAddress, buf []byte) []byte {
- p, _ := socket.ConvertAddress(family, originalDstAddress)
- level := uint32(linux.SOL_IP)
- optType := uint32(linux.IP_RECVORIGDSTADDR)
- if family == linux.AF_INET6 {
+func PackOriginalDstAddress(t *kernel.Task, originalDstAddress linux.SockAddr, buf []byte) []byte {
+ var level uint32
+ var optType uint32
+ switch originalDstAddress.(type) {
+ case *linux.SockAddrInet:
+ level = linux.SOL_IP
+ optType = linux.IP_RECVORIGDSTADDR
+ case *linux.SockAddrInet6:
level = linux.SOL_IPV6
optType = linux.IPV6_RECVORIGDSTADDR
+ default:
+ panic("invalid address type, must be an IP address for IP_RECVORIGINALDSTADDR cmsg")
}
return putCmsgStruct(
- buf, level, optType, t.Arch().Width(), p)
+ buf, level, optType, t.Arch().Width(), originalDstAddress)
}
// PackControlMessages packs control messages into the given buffer.
@@ -378,7 +377,7 @@ func PackOriginalDstAddress(t *kernel.Task, family int, originalDstAddress tcpip
//
// Note that some control messages may be truncated if they do not fit under
// the capacity of buf.
-func PackControlMessages(t *kernel.Task, family int, cmsgs socket.ControlMessages, buf []byte) []byte {
+func PackControlMessages(t *kernel.Task, cmsgs socket.ControlMessages, buf []byte) []byte {
if cmsgs.IP.HasTimestamp {
buf = PackTimestamp(t, cmsgs.IP.Timestamp, buf)
}
@@ -397,11 +396,11 @@ func PackControlMessages(t *kernel.Task, family int, cmsgs socket.ControlMessage
}
if cmsgs.IP.HasIPPacketInfo {
- buf = PackIPPacketInfo(t, cmsgs.IP.PacketInfo, buf)
+ buf = PackIPPacketInfo(t, &cmsgs.IP.PacketInfo, buf)
}
- if cmsgs.IP.HasOriginalDstAddress {
- buf = PackOriginalDstAddress(t, family, cmsgs.IP.OriginalDstAddress, buf)
+ if cmsgs.IP.OriginalDstAddress != nil {
+ buf = PackOriginalDstAddress(t, cmsgs.IP.OriginalDstAddress, buf)
}
return buf
@@ -433,17 +432,15 @@ func CmsgsSpace(t *kernel.Task, cmsgs socket.ControlMessages) int {
space += cmsgSpace(t, linux.SizeOfControlMessageTClass)
}
- return space
-}
+ if cmsgs.IP.HasIPPacketInfo {
+ space += cmsgSpace(t, linux.SizeOfControlMessageIPPacketInfo)
+ }
-// NewIPPacketInfo returns the IPPacketInfo struct.
-func NewIPPacketInfo(packetInfo linux.ControlMessageIPPacketInfo) tcpip.IPPacketInfo {
- var p tcpip.IPPacketInfo
- p.NIC = tcpip.NICID(packetInfo.NIC)
- copy([]byte(p.LocalAddr), packetInfo.LocalAddr[:])
- copy([]byte(p.DestinationAddr), packetInfo.DestinationAddr[:])
+ if cmsgs.IP.OriginalDstAddress != nil {
+ space += cmsgSpace(t, cmsgs.IP.OriginalDstAddress.SizeBytes())
+ }
- return p
+ return space
}
// Parse parses a raw socket control message into portable objects.
@@ -529,7 +526,7 @@ func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte) (socket.Con
var packetInfo linux.ControlMessageIPPacketInfo
binary.Unmarshal(buf[i:i+linux.SizeOfControlMessageIPPacketInfo], usermem.ByteOrder, &packetInfo)
- cmsgs.IP.PacketInfo = NewIPPacketInfo(packetInfo)
+ cmsgs.IP.PacketInfo = packetInfo
i += binary.AlignUp(length, width)
default: