summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorgVisor bot <gvisor-bot@google.com>2020-06-27 02:11:44 +0000
committergVisor bot <gvisor-bot@google.com>2020-06-27 02:11:44 +0000
commit934a1c32824f8e442640948d5907d5ad3a65215a (patch)
tree31702c7a5e7b16aae094dcb9e24e76bb18ace1cb
parent5bc6da314faf3f81b3bf089bb0d269dd642e40eb (diff)
parent66d1665441461a5226ba0c884e22888d58f393b6 (diff)
Merge release-20200622.1-34-g66d166544 (automated)
-rw-r--r--pkg/abi/linux/linux_abi_autogen_unsafe.go2
-rw-r--r--pkg/tcpip/transport/raw/endpoint.go61
2 files changed, 27 insertions, 36 deletions
diff --git a/pkg/abi/linux/linux_abi_autogen_unsafe.go b/pkg/abi/linux/linux_abi_autogen_unsafe.go
index ac656f14c..285e270b7 100644
--- a/pkg/abi/linux/linux_abi_autogen_unsafe.go
+++ b/pkg/abi/linux/linux_abi_autogen_unsafe.go
@@ -178,7 +178,7 @@ func (s *Statx) CopyOut(task marshal.Task, addr usermem.Addr) (int, error) {
// CopyIn implements marshal.Marshallable.CopyIn.
//go:nosplit
func (s *Statx) CopyIn(task marshal.Task, addr usermem.Addr) (int, error) {
- if !s.Atime.Packed() && s.Btime.Packed() && s.Ctime.Packed() && s.Mtime.Packed() {
+ if !s.Btime.Packed() && s.Ctime.Packed() && s.Mtime.Packed() && s.Atime.Packed() {
// Type Statx doesn't have a packed layout in memory, fall back to UnmarshalBytes.
buf := task.CopyScratchBuffer(s.SizeBytes()) // escapes: okay.
length, err := task.CopyInBytes(addr, buf) // escapes: okay.
diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go
index dd514d397..766c7648e 100644
--- a/pkg/tcpip/transport/raw/endpoint.go
+++ b/pkg/tcpip/transport/raw/endpoint.go
@@ -94,7 +94,7 @@ func NewEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, trans
}
func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue, associated bool) (tcpip.Endpoint, *tcpip.Error) {
- if netProto != header.IPv4ProtocolNumber {
+ if netProto != header.IPv4ProtocolNumber && netProto != header.IPv6ProtocolNumber {
return nil, tcpip.ErrUnknownProtocol
}
@@ -215,6 +215,11 @@ func (e *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMess
// Write implements tcpip.Endpoint.Write.
func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) {
+ // We can create, but not write to, unassociated IPv6 endpoints.
+ if !e.associated && e.TransportEndpointInfo.NetProto == header.IPv6ProtocolNumber {
+ return 0, nil, tcpip.ErrInvalidOptionValue
+ }
+
n, ch, err := e.write(p, opts)
switch err {
case nil:
@@ -319,12 +324,6 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c
return 0, nil, tcpip.ErrNoRoute
}
- // We don't support IPv6 yet, so this has to be an IPv4 address.
- if len(opts.To.Addr) != header.IPv4AddressSize {
- e.mu.RUnlock()
- return 0, nil, tcpip.ErrInvalidEndpointState
- }
-
// Find the route to the destination. If BindAddress is 0,
// FindRoute will choose an appropriate source address.
route, err := e.stack.FindRoute(nic, e.BindAddr, opts.To.Addr, e.NetProto, false)
@@ -354,17 +353,13 @@ func (e *endpoint) finishWrite(payloadBytes []byte, route *stack.Route) (int64,
}
}
- switch e.NetProto {
- case header.IPv4ProtocolNumber:
- if !e.associated {
- if err := route.WriteHeaderIncludedPacket(&stack.PacketBuffer{
- Data: buffer.View(payloadBytes).ToVectorisedView(),
- }); err != nil {
- return 0, nil, err
- }
- break
+ if !e.associated {
+ if err := route.WriteHeaderIncludedPacket(&stack.PacketBuffer{
+ Data: buffer.View(payloadBytes).ToVectorisedView(),
+ }); err != nil {
+ return 0, nil, err
}
-
+ } else {
hdr := buffer.NewPrependable(len(payloadBytes) + int(route.MaxHeaderLength()))
if err := route.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: e.TransProto, TTL: route.DefaultTTL(), TOS: stack.DefaultTOS}, &stack.PacketBuffer{
Header: hdr,
@@ -373,9 +368,6 @@ func (e *endpoint) finishWrite(payloadBytes []byte, route *stack.Route) (int64,
}); err != nil {
return 0, nil, err
}
-
- default:
- return 0, nil, tcpip.ErrUnknownProtocol
}
return int64(len(payloadBytes)), nil, nil
@@ -400,11 +392,6 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
return tcpip.ErrInvalidEndpointState
}
- // We don't support IPv6 yet.
- if len(addr.Addr) != header.IPv4AddressSize {
- return tcpip.ErrInvalidEndpointState
- }
-
nic := addr.NIC
if e.bound {
if e.BindNICID == 0 {
@@ -470,14 +457,8 @@ func (e *endpoint) Bind(addr tcpip.FullAddress) *tcpip.Error {
e.mu.Lock()
defer e.mu.Unlock()
- // Callers must provide an IPv4 address or no network address (for
- // binding to a NIC, but not an address).
- if len(addr.Addr) != 0 && len(addr.Addr) != 4 {
- return tcpip.ErrInvalidEndpointState
- }
-
// If a local address was specified, verify that it's valid.
- if len(addr.Addr) == header.IPv4AddressSize && e.stack.CheckLocalAddress(addr.NIC, e.NetProto, addr.Addr) == 0 {
+ if e.stack.CheckLocalAddress(addr.NIC, e.NetProto, addr.Addr) == 0 {
return tcpip.ErrBadLocalAddress
}
@@ -680,9 +661,19 @@ func (e *endpoint) HandlePacket(route *stack.Route, pkt *stack.PacketBuffer) {
},
}
- headers := append(buffer.View(nil), pkt.NetworkHeader...)
- headers = append(headers, pkt.TransportHeader...)
- combinedVV := headers.ToVectorisedView()
+ // Raw IPv4 endpoints return the IP header, but IPv6 endpoints do not.
+ // We copy headers' underlying bytes because pkt.*Header may point to
+ // the middle of a slice, and another struct may point to the "outer"
+ // slice. Save/restore doesn't support overlapping slices and will fail.
+ var combinedVV buffer.VectorisedView
+ if e.TransportEndpointInfo.NetProto == header.IPv4ProtocolNumber {
+ headers := make(buffer.View, 0, len(pkt.NetworkHeader)+len(pkt.TransportHeader))
+ headers = append(headers, pkt.NetworkHeader...)
+ headers = append(headers, pkt.TransportHeader...)
+ combinedVV = headers.ToVectorisedView()
+ } else {
+ combinedVV = append(buffer.View(nil), pkt.TransportHeader...).ToVectorisedView()
+ }
combinedVV.Append(pkt.Data)
packet.data = combinedVV
packet.timestampNS = e.stack.NowNanoseconds()