summaryrefslogtreecommitdiffhomepage
path: root/pkg/tcpip/transport/packet/endpoint.go
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/tcpip/transport/packet/endpoint.go')
-rw-r--r--pkg/tcpip/transport/packet/endpoint.go85
1 files changed, 59 insertions, 26 deletions
diff --git a/pkg/tcpip/transport/packet/endpoint.go b/pkg/tcpip/transport/packet/endpoint.go
index 73cdaa265..23158173d 100644
--- a/pkg/tcpip/transport/packet/endpoint.go
+++ b/pkg/tcpip/transport/packet/endpoint.go
@@ -25,12 +25,10 @@
package packet
import (
- "sync"
-
+ "gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/header"
- "gvisor.dev/gvisor/pkg/tcpip/iptables"
"gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/waiter"
)
@@ -41,10 +39,6 @@ type packet struct {
// data holds the actual packet data, including any headers and
// payload.
data buffer.VectorisedView `state:".(buffer.VectorisedView)"`
- // views is pre-allocated space to back data. As long as the packet is
- // made up of fewer than 8 buffer.Views, no extra allocation is
- // necessary to store packet data.
- views [8]buffer.View `state:"nosave"`
// timestampNS is the unix time at which the packet was received.
timestampNS int64
// senderAddr is the network address of the sender.
@@ -81,6 +75,7 @@ type endpoint struct {
sndBufSize int
closed bool
stats tcpip.TransportEndpointStats `state:"nosave"`
+ bound bool
}
// NewEndpoint returns a new packet endpoint.
@@ -103,6 +98,11 @@ func NewEndpoint(s *stack.Stack, cooked bool, netProto tcpip.NetworkProtocolNumb
return ep, nil
}
+// Abort implements stack.TransportEndpoint.Abort.
+func (ep *endpoint) Abort() {
+ ep.Close()
+}
+
// Close implements tcpip.Endpoint.Close.
func (ep *endpoint) Close() {
ep.mu.Lock()
@@ -125,6 +125,7 @@ func (ep *endpoint) Close() {
}
ep.closed = true
+ ep.bound = false
ep.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.EventIn | waiter.EventOut)
}
@@ -132,7 +133,7 @@ func (ep *endpoint) Close() {
func (ep *endpoint) ModerateRecvBuf(copied int) {}
// IPTables implements tcpip.Endpoint.IPTables.
-func (ep *endpoint) IPTables() (iptables.IPTables, error) {
+func (ep *endpoint) IPTables() (stack.IPTables, error) {
return ep.stack.IPTables(), nil
}
@@ -216,7 +217,24 @@ func (ep *endpoint) Bind(addr tcpip.FullAddress) *tcpip.Error {
// sll_family (should be AF_PACKET), sll_protocol, and sll_ifindex."
// - packet(7).
- return tcpip.ErrNotSupported
+ ep.mu.Lock()
+ defer ep.mu.Unlock()
+
+ if ep.bound {
+ return tcpip.ErrAlreadyBound
+ }
+
+ // Unregister endpoint with all the nics.
+ ep.stack.UnregisterPacketEndpoint(0, ep.netProto, ep)
+
+ // Bind endpoint to receive packets from specific interface.
+ if err := ep.stack.RegisterPacketEndpoint(addr.NIC, ep.netProto, ep); err != nil {
+ return err
+ }
+
+ ep.bound = true
+
+ return nil
}
// GetLocalAddress implements tcpip.Endpoint.GetLocalAddress.
@@ -251,17 +269,17 @@ func (ep *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask {
// used with SetSockOpt, and this function always returns
// tcpip.ErrNotSupported.
func (ep *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
- return tcpip.ErrNotSupported
+ return tcpip.ErrUnknownProtocolOption
}
-// SetSockOptInt implements tcpip.Endpoint.SetSockOptInt.
-func (ep *endpoint) SetSockOptInt(opt tcpip.SockOpt, v int) *tcpip.Error {
+// SetSockOptBool implements tcpip.Endpoint.SetSockOptBool.
+func (ep *endpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error {
return tcpip.ErrUnknownProtocolOption
}
-// GetSockOptInt implements tcpip.Endpoint.GetSockOptInt.
-func (ep *endpoint) GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error) {
- return 0, tcpip.ErrNotSupported
+// SetSockOptInt implements tcpip.Endpoint.SetSockOptInt.
+func (ep *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error {
+ return tcpip.ErrUnknownProtocolOption
}
// GetSockOpt implements tcpip.Endpoint.GetSockOpt.
@@ -269,8 +287,18 @@ func (ep *endpoint) GetSockOpt(opt interface{}) *tcpip.Error {
return tcpip.ErrNotSupported
}
+// GetSockOptBool implements tcpip.Endpoint.GetSockOptBool.
+func (ep *endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) {
+ return false, tcpip.ErrNotSupported
+}
+
+// GetSockOptInt implements tcpip.Endpoint.GetSockOptInt.
+func (ep *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) {
+ return 0, tcpip.ErrNotSupported
+}
+
// HandlePacket implements stack.PacketEndpoint.HandlePacket.
-func (ep *endpoint) HandlePacket(nicid tcpip.NICID, localAddr tcpip.LinkAddress, netProto tcpip.NetworkProtocolNumber, vv buffer.VectorisedView, ethHeader buffer.View) {
+func (ep *endpoint) HandlePacket(nicID tcpip.NICID, localAddr tcpip.LinkAddress, netProto tcpip.NetworkProtocolNumber, pkt stack.PacketBuffer) {
ep.rcvMu.Lock()
// Drop the packet if our buffer is currently full.
@@ -293,28 +321,29 @@ func (ep *endpoint) HandlePacket(nicid tcpip.NICID, localAddr tcpip.LinkAddress,
// Push new packet into receive list and increment the buffer size.
var packet packet
// TODO(b/129292371): Return network protocol.
- if len(ethHeader) > 0 {
+ if len(pkt.LinkHeader) > 0 {
// Get info directly from the ethernet header.
- hdr := header.Ethernet(ethHeader)
+ hdr := header.Ethernet(pkt.LinkHeader)
packet.senderAddr = tcpip.FullAddress{
- NIC: nicid,
+ NIC: nicID,
Addr: tcpip.Address(hdr.SourceAddress()),
}
} else {
// Guess the would-be ethernet header.
packet.senderAddr = tcpip.FullAddress{
- NIC: nicid,
+ NIC: nicID,
Addr: tcpip.Address(localAddr),
}
}
if ep.cooked {
// Cooked packets can simply be queued.
- packet.data = vv.Clone(packet.views[:])
+ packet.data = pkt.Data
} else {
// Raw packets need their ethernet headers prepended before
// queueing.
- if len(ethHeader) == 0 {
+ var linkHeader buffer.View
+ if len(pkt.LinkHeader) == 0 {
// We weren't provided with an actual ethernet header,
// so fake one.
ethFields := header.EthernetFields{
@@ -324,11 +353,13 @@ func (ep *endpoint) HandlePacket(nicid tcpip.NICID, localAddr tcpip.LinkAddress,
}
fakeHeader := make(header.Ethernet, header.EthernetMinimumSize)
fakeHeader.Encode(&ethFields)
- ethHeader = buffer.View(fakeHeader)
+ linkHeader = buffer.View(fakeHeader)
+ } else {
+ linkHeader = append(buffer.View(nil), pkt.LinkHeader...)
}
- combinedVV := buffer.View(ethHeader).ToVectorisedView()
- combinedVV.Append(vv)
- packet.data = combinedVV.Clone(packet.views[:])
+ combinedVV := linkHeader.ToVectorisedView()
+ combinedVV.Append(pkt.Data)
+ packet.data = combinedVV
}
packet.timestampNS = ep.stack.NowNanoseconds()
@@ -361,3 +392,5 @@ func (ep *endpoint) Info() tcpip.EndpointInfo {
func (ep *endpoint) Stats() tcpip.EndpointStats {
return &ep.stats
}
+
+func (ep *endpoint) SetOwner(owner tcpip.PacketOwner) {}