diff options
Diffstat (limited to 'pkg/tcpip')
-rwxr-xr-x | pkg/tcpip/transport/packet/endpoint.go | 21 | ||||
-rwxr-xr-x | pkg/tcpip/transport/packet/packet_state_autogen.go | 2 |
2 files changed, 22 insertions, 1 deletions
diff --git a/pkg/tcpip/transport/packet/endpoint.go b/pkg/tcpip/transport/packet/endpoint.go index 5722815e9..09a1cd436 100755 --- a/pkg/tcpip/transport/packet/endpoint.go +++ b/pkg/tcpip/transport/packet/endpoint.go @@ -76,6 +76,7 @@ type endpoint struct { sndBufSize int closed bool stats tcpip.TransportEndpointStats `state:"nosave"` + bound bool } // NewEndpoint returns a new packet endpoint. @@ -125,6 +126,7 @@ func (ep *endpoint) Close() { } ep.closed = true + ep.bound = false ep.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.EventIn | waiter.EventOut) } @@ -216,7 +218,24 @@ func (ep *endpoint) Bind(addr tcpip.FullAddress) *tcpip.Error { // sll_family (should be AF_PACKET), sll_protocol, and sll_ifindex." // - packet(7). - return tcpip.ErrNotSupported + ep.mu.Lock() + defer ep.mu.Unlock() + + if ep.bound { + return tcpip.ErrAlreadyBound + } + + // Unregister endpoint with all the nics. + ep.stack.UnregisterPacketEndpoint(0, ep.netProto, ep) + + // Bind endpoint to receive packets from specific interface. + if err := ep.stack.RegisterPacketEndpoint(addr.NIC, ep.netProto, ep); err != nil { + return err + } + + ep.bound = true + + return nil } // GetLocalAddress implements tcpip.Endpoint.GetLocalAddress. diff --git a/pkg/tcpip/transport/packet/packet_state_autogen.go b/pkg/tcpip/transport/packet/packet_state_autogen.go index 1d90a383f..8ff339e08 100755 --- a/pkg/tcpip/transport/packet/packet_state_autogen.go +++ b/pkg/tcpip/transport/packet/packet_state_autogen.go @@ -38,6 +38,7 @@ func (x *endpoint) save(m state.Map) { m.Save("rcvClosed", &x.rcvClosed) m.Save("sndBufSize", &x.sndBufSize) m.Save("closed", &x.closed) + m.Save("bound", &x.bound) } func (x *endpoint) load(m state.Map) { @@ -50,6 +51,7 @@ func (x *endpoint) load(m state.Map) { m.Load("rcvClosed", &x.rcvClosed) m.Load("sndBufSize", &x.sndBufSize) m.Load("closed", &x.closed) + m.Load("bound", &x.bound) m.LoadValue("rcvBufSizeMax", new(int), func(y interface{}) { x.loadRcvBufSizeMax(y.(int)) }) m.AfterLoad(x.afterLoad) } |