diff options
-rw-r--r-- | pkg/abi/linux/linux_abi_autogen_unsafe.go | 4 | ||||
-rw-r--r-- | pkg/syserr/netstack.go | 2 | ||||
-rw-r--r-- | pkg/tcpip/transport/packet/endpoint.go | 25 | ||||
-rw-r--r-- | pkg/tcpip/transport/packet/endpoint_state.go | 19 | ||||
-rw-r--r-- | pkg/tcpip/transport/packet/packet_state_autogen.go | 7 |
5 files changed, 52 insertions, 5 deletions
diff --git a/pkg/abi/linux/linux_abi_autogen_unsafe.go b/pkg/abi/linux/linux_abi_autogen_unsafe.go index d50ebe915..ac656f14c 100644 --- a/pkg/abi/linux/linux_abi_autogen_unsafe.go +++ b/pkg/abi/linux/linux_abi_autogen_unsafe.go @@ -148,7 +148,7 @@ func (s *Statx) UnmarshalUnsafe(src []byte) { // CopyOutN implements marshal.Marshallable.CopyOutN. //go:nosplit func (s *Statx) CopyOutN(task marshal.Task, addr usermem.Addr, limit int) (int, error) { - if !s.Mtime.Packed() && s.Atime.Packed() && s.Btime.Packed() && s.Ctime.Packed() { + if !s.Atime.Packed() && s.Btime.Packed() && s.Ctime.Packed() && s.Mtime.Packed() { // Type Statx doesn't have a packed layout in memory, fall back to MarshalBytes. buf := task.CopyScratchBuffer(s.SizeBytes()) // escapes: okay. s.MarshalBytes(buf) // escapes: fallback. @@ -178,7 +178,7 @@ func (s *Statx) CopyOut(task marshal.Task, addr usermem.Addr) (int, error) { // CopyIn implements marshal.Marshallable.CopyIn. //go:nosplit func (s *Statx) CopyIn(task marshal.Task, addr usermem.Addr) (int, error) { - if !s.Ctime.Packed() && s.Mtime.Packed() && s.Atime.Packed() && s.Btime.Packed() { + if !s.Atime.Packed() && s.Btime.Packed() && s.Ctime.Packed() && s.Mtime.Packed() { // Type Statx doesn't have a packed layout in memory, fall back to UnmarshalBytes. buf := task.CopyScratchBuffer(s.SizeBytes()) // escapes: okay. length, err := task.CopyInBytes(addr, buf) // escapes: okay. diff --git a/pkg/syserr/netstack.go b/pkg/syserr/netstack.go index 8ff922c69..5ae10939d 100644 --- a/pkg/syserr/netstack.go +++ b/pkg/syserr/netstack.go @@ -22,7 +22,7 @@ import ( // Mapping for tcpip.Error types. var ( ErrUnknownProtocol = New(tcpip.ErrUnknownProtocol.String(), linux.EINVAL) - ErrUnknownNICID = New(tcpip.ErrUnknownNICID.String(), linux.EINVAL) + ErrUnknownNICID = New(tcpip.ErrUnknownNICID.String(), linux.ENODEV) ErrUnknownDevice = New(tcpip.ErrUnknownDevice.String(), linux.ENODEV) ErrUnknownProtocolOption = New(tcpip.ErrUnknownProtocolOption.String(), linux.ENOPROTOOPT) ErrDuplicateNICID = New(tcpip.ErrDuplicateNICID.String(), linux.EEXIST) diff --git a/pkg/tcpip/transport/packet/endpoint.go b/pkg/tcpip/transport/packet/endpoint.go index 57b7f5c19..92b487381 100644 --- a/pkg/tcpip/transport/packet/endpoint.go +++ b/pkg/tcpip/transport/packet/endpoint.go @@ -79,6 +79,11 @@ type endpoint struct { closed bool stats tcpip.TransportEndpointStats `state:"nosave"` bound bool + boundNIC tcpip.NICID + + // lastErrorMu protects lastError. + lastErrorMu sync.Mutex `state:"nosave"` + lastError *tcpip.Error `state:".(string)"` } // NewEndpoint returns a new packet endpoint. @@ -229,12 +234,14 @@ func (ep *endpoint) Bind(addr tcpip.FullAddress) *tcpip.Error { ep.mu.Lock() defer ep.mu.Unlock() - if ep.bound { - return tcpip.ErrAlreadyBound + if ep.bound && ep.boundNIC == addr.NIC { + // If the NIC being bound is the same then just return success. + return nil } // Unregister endpoint with all the nics. ep.stack.UnregisterPacketEndpoint(0, ep.netProto, ep) + ep.bound = false // Bind endpoint to receive packets from specific interface. if err := ep.stack.RegisterPacketEndpoint(addr.NIC, ep.netProto, ep); err != nil { @@ -242,6 +249,7 @@ func (ep *endpoint) Bind(addr tcpip.FullAddress) *tcpip.Error { } ep.bound = true + ep.boundNIC = addr.NIC return nil } @@ -336,8 +344,21 @@ func (ep *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { } } +func (ep *endpoint) takeLastError() *tcpip.Error { + ep.lastErrorMu.Lock() + defer ep.lastErrorMu.Unlock() + + err := ep.lastError + ep.lastError = nil + return err +} + // GetSockOpt implements tcpip.Endpoint.GetSockOpt. func (ep *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { + switch opt.(type) { + case tcpip.ErrorOption: + return ep.takeLastError() + } return tcpip.ErrNotSupported } diff --git a/pkg/tcpip/transport/packet/endpoint_state.go b/pkg/tcpip/transport/packet/endpoint_state.go index 9b88f17e4..e2fa96d17 100644 --- a/pkg/tcpip/transport/packet/endpoint_state.go +++ b/pkg/tcpip/transport/packet/endpoint_state.go @@ -15,6 +15,7 @@ package packet import ( + "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/stack" ) @@ -70,3 +71,21 @@ func (ep *endpoint) afterLoad() { panic(*err) } } + +// saveLastError is invoked by stateify. +func (ep *endpoint) saveLastError() string { + if ep.lastError == nil { + return "" + } + + return ep.lastError.String() +} + +// loadLastError is invoked by stateify. +func (ep *endpoint) loadLastError(s string) { + if s == "" { + return + } + + ep.lastError = tcpip.StringToError(s) +} diff --git a/pkg/tcpip/transport/packet/packet_state_autogen.go b/pkg/tcpip/transport/packet/packet_state_autogen.go index 06ad0b846..b7fa1cdc9 100644 --- a/pkg/tcpip/transport/packet/packet_state_autogen.go +++ b/pkg/tcpip/transport/packet/packet_state_autogen.go @@ -58,6 +58,8 @@ func (x *endpoint) StateFields() []string { "sndBufSizeMax", "closed", "bound", + "boundNIC", + "lastError", } } @@ -65,6 +67,8 @@ func (x *endpoint) StateSave(m state.Sink) { x.beforeSave() var rcvBufSizeMax int = x.saveRcvBufSizeMax() m.SaveValue(5, rcvBufSizeMax) + var lastError string = x.saveLastError() + m.SaveValue(13, lastError) m.Save(0, &x.TransportEndpointInfo) m.Save(1, &x.netProto) m.Save(2, &x.waiterQueue) @@ -76,6 +80,7 @@ func (x *endpoint) StateSave(m state.Sink) { m.Save(9, &x.sndBufSizeMax) m.Save(10, &x.closed) m.Save(11, &x.bound) + m.Save(12, &x.boundNIC) } func (x *endpoint) StateLoad(m state.Source) { @@ -90,7 +95,9 @@ func (x *endpoint) StateLoad(m state.Source) { m.Load(9, &x.sndBufSizeMax) m.Load(10, &x.closed) m.Load(11, &x.bound) + m.Load(12, &x.boundNIC) m.LoadValue(5, new(int), func(y interface{}) { x.loadRcvBufSizeMax(y.(int)) }) + m.LoadValue(13, new(string), func(y interface{}) { x.loadLastError(y.(string)) }) m.AfterLoad(x.afterLoad) } |