summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--pkg/tcpip/transport/raw/endpoint.go17
-rw-r--r--test/syscalls/linux/raw_socket.cc63
2 files changed, 78 insertions, 2 deletions
diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go
index ab5da987a..b3d8951ff 100644
--- a/pkg/tcpip/transport/raw/endpoint.go
+++ b/pkg/tcpip/transport/raw/endpoint.go
@@ -455,8 +455,21 @@ func (e *endpoint) Bind(addr tcpip.FullAddress) tcpip.Error {
}
// GetLocalAddress implements tcpip.Endpoint.GetLocalAddress.
-func (*endpoint) GetLocalAddress() (tcpip.FullAddress, tcpip.Error) {
- return tcpip.FullAddress{}, &tcpip.ErrNotSupported{}
+func (e *endpoint) GetLocalAddress() (tcpip.FullAddress, tcpip.Error) {
+ e.mu.RLock()
+ defer e.mu.RUnlock()
+
+ addr := e.BindAddr
+ if e.connected {
+ addr = e.route.LocalAddress()
+ }
+
+ return tcpip.FullAddress{
+ NIC: e.RegisterNICID,
+ Addr: addr,
+ // Linux returns the protocol in the port field.
+ Port: uint16(e.TransProto),
+ }, nil
}
// GetRemoteAddress implements tcpip.Endpoint.GetRemoteAddress.
diff --git a/test/syscalls/linux/raw_socket.cc b/test/syscalls/linux/raw_socket.cc
index 69616b400..f8798bc76 100644
--- a/test/syscalls/linux/raw_socket.cc
+++ b/test/syscalls/linux/raw_socket.cc
@@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
+#include <arpa/inet.h>
#include <linux/capability.h>
#include <linux/filter.h>
#include <netinet/in.h>
@@ -76,6 +77,20 @@ class RawSocketTest : public ::testing::TestWithParam<std::tuple<int, int>> {
return 0;
}
+ uint16_t Port(struct sockaddr* s) {
+ if (Family() == AF_INET) {
+ return ntohs(reinterpret_cast<struct sockaddr_in*>(s)->sin_port);
+ }
+ return ntohs(reinterpret_cast<struct sockaddr_in6*>(s)->sin6_port);
+ }
+
+ void* Addr(struct sockaddr* s) {
+ if (Family() == AF_INET) {
+ return &(reinterpret_cast<struct sockaddr_in*>(s)->sin_addr);
+ }
+ return &(reinterpret_cast<struct sockaddr_in6*>(s)->sin6_addr);
+ }
+
// The socket used for both reading and writing.
int s_;
@@ -181,6 +196,54 @@ TEST_P(RawSocketTest, FailAccept) {
ASSERT_THAT(accept(s_, &saddr, &addrlen), SyscallFailsWithErrno(ENOTSUP));
}
+TEST_P(RawSocketTest, BindThenGetSockName) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
+
+ struct sockaddr* addr = reinterpret_cast<struct sockaddr*>(&addr_);
+ ASSERT_THAT(bind(s_, addr, AddrLen()), SyscallSucceeds());
+ struct sockaddr_storage saddr_storage;
+ struct sockaddr* saddr = reinterpret_cast<struct sockaddr*>(&saddr_storage);
+ socklen_t saddrlen = AddrLen();
+ ASSERT_THAT(getsockname(s_, saddr, &saddrlen), SyscallSucceeds());
+ ASSERT_EQ(saddrlen, AddrLen());
+
+ // The port is expected to hold the protocol number.
+ EXPECT_EQ(Port(saddr), Protocol());
+
+ char addrbuf[INET6_ADDRSTRLEN], saddrbuf[INET6_ADDRSTRLEN];
+ const char* addrstr =
+ inet_ntop(addr->sa_family, Addr(addr), addrbuf, sizeof(addrbuf));
+ ASSERT_NE(addrstr, nullptr);
+ const char* saddrstr =
+ inet_ntop(saddr->sa_family, Addr(saddr), saddrbuf, sizeof(saddrbuf));
+ ASSERT_NE(saddrstr, nullptr);
+ EXPECT_STREQ(saddrstr, addrstr);
+}
+
+TEST_P(RawSocketTest, ConnectThenGetSockName) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));
+
+ struct sockaddr* addr = reinterpret_cast<struct sockaddr*>(&addr_);
+ ASSERT_THAT(connect(s_, addr, AddrLen()), SyscallSucceeds());
+ struct sockaddr_storage saddr_storage;
+ struct sockaddr* saddr = reinterpret_cast<struct sockaddr*>(&saddr_storage);
+ socklen_t saddrlen = AddrLen();
+ ASSERT_THAT(getsockname(s_, saddr, &saddrlen), SyscallSucceeds());
+ ASSERT_EQ(saddrlen, AddrLen());
+
+ // The port is expected to hold the protocol number.
+ EXPECT_EQ(Port(saddr), Protocol());
+
+ char addrbuf[INET6_ADDRSTRLEN], saddrbuf[INET6_ADDRSTRLEN];
+ const char* addrstr =
+ inet_ntop(addr->sa_family, Addr(addr), addrbuf, sizeof(addrbuf));
+ ASSERT_NE(addrstr, nullptr);
+ const char* saddrstr =
+ inet_ntop(saddr->sa_family, Addr(saddr), saddrbuf, sizeof(saddrbuf));
+ ASSERT_NE(saddrstr, nullptr);
+ EXPECT_STREQ(saddrstr, addrstr);
+}
+
// Test that getpeername() returns nothing before connect().
TEST_P(RawSocketTest, FailGetPeerNameBeforeConnect) {
SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW)));