summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--pkg/abi/linux/linux_abi_autogen_unsafe.go4
-rw-r--r--pkg/syserr/netstack.go2
-rw-r--r--pkg/tcpip/transport/packet/endpoint.go25
-rw-r--r--pkg/tcpip/transport/packet/endpoint_state.go19
-rw-r--r--pkg/tcpip/transport/packet/packet_state_autogen.go7
5 files changed, 52 insertions, 5 deletions
diff --git a/pkg/abi/linux/linux_abi_autogen_unsafe.go b/pkg/abi/linux/linux_abi_autogen_unsafe.go
index d50ebe915..ac656f14c 100644
--- a/pkg/abi/linux/linux_abi_autogen_unsafe.go
+++ b/pkg/abi/linux/linux_abi_autogen_unsafe.go
@@ -148,7 +148,7 @@ func (s *Statx) UnmarshalUnsafe(src []byte) {
// CopyOutN implements marshal.Marshallable.CopyOutN.
//go:nosplit
func (s *Statx) CopyOutN(task marshal.Task, addr usermem.Addr, limit int) (int, error) {
- if !s.Mtime.Packed() && s.Atime.Packed() && s.Btime.Packed() && s.Ctime.Packed() {
+ if !s.Atime.Packed() && s.Btime.Packed() && s.Ctime.Packed() && s.Mtime.Packed() {
// Type Statx doesn't have a packed layout in memory, fall back to MarshalBytes.
buf := task.CopyScratchBuffer(s.SizeBytes()) // escapes: okay.
s.MarshalBytes(buf) // escapes: fallback.
@@ -178,7 +178,7 @@ func (s *Statx) CopyOut(task marshal.Task, addr usermem.Addr) (int, error) {
// CopyIn implements marshal.Marshallable.CopyIn.
//go:nosplit
func (s *Statx) CopyIn(task marshal.Task, addr usermem.Addr) (int, error) {
- if !s.Ctime.Packed() && s.Mtime.Packed() && s.Atime.Packed() && s.Btime.Packed() {
+ if !s.Atime.Packed() && s.Btime.Packed() && s.Ctime.Packed() && s.Mtime.Packed() {
// Type Statx doesn't have a packed layout in memory, fall back to UnmarshalBytes.
buf := task.CopyScratchBuffer(s.SizeBytes()) // escapes: okay.
length, err := task.CopyInBytes(addr, buf) // escapes: okay.
diff --git a/pkg/syserr/netstack.go b/pkg/syserr/netstack.go
index 8ff922c69..5ae10939d 100644
--- a/pkg/syserr/netstack.go
+++ b/pkg/syserr/netstack.go
@@ -22,7 +22,7 @@ import (
// Mapping for tcpip.Error types.
var (
ErrUnknownProtocol = New(tcpip.ErrUnknownProtocol.String(), linux.EINVAL)
- ErrUnknownNICID = New(tcpip.ErrUnknownNICID.String(), linux.EINVAL)
+ ErrUnknownNICID = New(tcpip.ErrUnknownNICID.String(), linux.ENODEV)
ErrUnknownDevice = New(tcpip.ErrUnknownDevice.String(), linux.ENODEV)
ErrUnknownProtocolOption = New(tcpip.ErrUnknownProtocolOption.String(), linux.ENOPROTOOPT)
ErrDuplicateNICID = New(tcpip.ErrDuplicateNICID.String(), linux.EEXIST)
diff --git a/pkg/tcpip/transport/packet/endpoint.go b/pkg/tcpip/transport/packet/endpoint.go
index 57b7f5c19..92b487381 100644
--- a/pkg/tcpip/transport/packet/endpoint.go
+++ b/pkg/tcpip/transport/packet/endpoint.go
@@ -79,6 +79,11 @@ type endpoint struct {
closed bool
stats tcpip.TransportEndpointStats `state:"nosave"`
bound bool
+ boundNIC tcpip.NICID
+
+ // lastErrorMu protects lastError.
+ lastErrorMu sync.Mutex `state:"nosave"`
+ lastError *tcpip.Error `state:".(string)"`
}
// NewEndpoint returns a new packet endpoint.
@@ -229,12 +234,14 @@ func (ep *endpoint) Bind(addr tcpip.FullAddress) *tcpip.Error {
ep.mu.Lock()
defer ep.mu.Unlock()
- if ep.bound {
- return tcpip.ErrAlreadyBound
+ if ep.bound && ep.boundNIC == addr.NIC {
+ // If the NIC being bound is the same then just return success.
+ return nil
}
// Unregister endpoint with all the nics.
ep.stack.UnregisterPacketEndpoint(0, ep.netProto, ep)
+ ep.bound = false
// Bind endpoint to receive packets from specific interface.
if err := ep.stack.RegisterPacketEndpoint(addr.NIC, ep.netProto, ep); err != nil {
@@ -242,6 +249,7 @@ func (ep *endpoint) Bind(addr tcpip.FullAddress) *tcpip.Error {
}
ep.bound = true
+ ep.boundNIC = addr.NIC
return nil
}
@@ -336,8 +344,21 @@ func (ep *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error {
}
}
+func (ep *endpoint) takeLastError() *tcpip.Error {
+ ep.lastErrorMu.Lock()
+ defer ep.lastErrorMu.Unlock()
+
+ err := ep.lastError
+ ep.lastError = nil
+ return err
+}
+
// GetSockOpt implements tcpip.Endpoint.GetSockOpt.
func (ep *endpoint) GetSockOpt(opt interface{}) *tcpip.Error {
+ switch opt.(type) {
+ case tcpip.ErrorOption:
+ return ep.takeLastError()
+ }
return tcpip.ErrNotSupported
}
diff --git a/pkg/tcpip/transport/packet/endpoint_state.go b/pkg/tcpip/transport/packet/endpoint_state.go
index 9b88f17e4..e2fa96d17 100644
--- a/pkg/tcpip/transport/packet/endpoint_state.go
+++ b/pkg/tcpip/transport/packet/endpoint_state.go
@@ -15,6 +15,7 @@
package packet
import (
+ "gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/stack"
)
@@ -70,3 +71,21 @@ func (ep *endpoint) afterLoad() {
panic(*err)
}
}
+
+// saveLastError is invoked by stateify.
+func (ep *endpoint) saveLastError() string {
+ if ep.lastError == nil {
+ return ""
+ }
+
+ return ep.lastError.String()
+}
+
+// loadLastError is invoked by stateify.
+func (ep *endpoint) loadLastError(s string) {
+ if s == "" {
+ return
+ }
+
+ ep.lastError = tcpip.StringToError(s)
+}
diff --git a/pkg/tcpip/transport/packet/packet_state_autogen.go b/pkg/tcpip/transport/packet/packet_state_autogen.go
index 06ad0b846..b7fa1cdc9 100644
--- a/pkg/tcpip/transport/packet/packet_state_autogen.go
+++ b/pkg/tcpip/transport/packet/packet_state_autogen.go
@@ -58,6 +58,8 @@ func (x *endpoint) StateFields() []string {
"sndBufSizeMax",
"closed",
"bound",
+ "boundNIC",
+ "lastError",
}
}
@@ -65,6 +67,8 @@ func (x *endpoint) StateSave(m state.Sink) {
x.beforeSave()
var rcvBufSizeMax int = x.saveRcvBufSizeMax()
m.SaveValue(5, rcvBufSizeMax)
+ var lastError string = x.saveLastError()
+ m.SaveValue(13, lastError)
m.Save(0, &x.TransportEndpointInfo)
m.Save(1, &x.netProto)
m.Save(2, &x.waiterQueue)
@@ -76,6 +80,7 @@ func (x *endpoint) StateSave(m state.Sink) {
m.Save(9, &x.sndBufSizeMax)
m.Save(10, &x.closed)
m.Save(11, &x.bound)
+ m.Save(12, &x.boundNIC)
}
func (x *endpoint) StateLoad(m state.Source) {
@@ -90,7 +95,9 @@ func (x *endpoint) StateLoad(m state.Source) {
m.Load(9, &x.sndBufSizeMax)
m.Load(10, &x.closed)
m.Load(11, &x.bound)
+ m.Load(12, &x.boundNIC)
m.LoadValue(5, new(int), func(y interface{}) { x.loadRcvBufSizeMax(y.(int)) })
+ m.LoadValue(13, new(string), func(y interface{}) { x.loadLastError(y.(string)) })
m.AfterLoad(x.afterLoad)
}