diff options
-rw-r--r-- | pkg/tcpip/transport/raw/endpoint.go | 17 | ||||
-rw-r--r-- | test/syscalls/linux/raw_socket.cc | 63 |
2 files changed, 78 insertions, 2 deletions
diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go index ab5da987a..b3d8951ff 100644 --- a/pkg/tcpip/transport/raw/endpoint.go +++ b/pkg/tcpip/transport/raw/endpoint.go @@ -455,8 +455,21 @@ func (e *endpoint) Bind(addr tcpip.FullAddress) tcpip.Error { } // GetLocalAddress implements tcpip.Endpoint.GetLocalAddress. -func (*endpoint) GetLocalAddress() (tcpip.FullAddress, tcpip.Error) { - return tcpip.FullAddress{}, &tcpip.ErrNotSupported{} +func (e *endpoint) GetLocalAddress() (tcpip.FullAddress, tcpip.Error) { + e.mu.RLock() + defer e.mu.RUnlock() + + addr := e.BindAddr + if e.connected { + addr = e.route.LocalAddress() + } + + return tcpip.FullAddress{ + NIC: e.RegisterNICID, + Addr: addr, + // Linux returns the protocol in the port field. + Port: uint16(e.TransProto), + }, nil } // GetRemoteAddress implements tcpip.Endpoint.GetRemoteAddress. diff --git a/test/syscalls/linux/raw_socket.cc b/test/syscalls/linux/raw_socket.cc index 69616b400..f8798bc76 100644 --- a/test/syscalls/linux/raw_socket.cc +++ b/test/syscalls/linux/raw_socket.cc @@ -12,6 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include <arpa/inet.h> #include <linux/capability.h> #include <linux/filter.h> #include <netinet/in.h> @@ -76,6 +77,20 @@ class RawSocketTest : public ::testing::TestWithParam<std::tuple<int, int>> { return 0; } + uint16_t Port(struct sockaddr* s) { + if (Family() == AF_INET) { + return ntohs(reinterpret_cast<struct sockaddr_in*>(s)->sin_port); + } + return ntohs(reinterpret_cast<struct sockaddr_in6*>(s)->sin6_port); + } + + void* Addr(struct sockaddr* s) { + if (Family() == AF_INET) { + return &(reinterpret_cast<struct sockaddr_in*>(s)->sin_addr); + } + return &(reinterpret_cast<struct sockaddr_in6*>(s)->sin6_addr); + } + // The socket used for both reading and writing. int s_; @@ -181,6 +196,54 @@ TEST_P(RawSocketTest, FailAccept) { ASSERT_THAT(accept(s_, &saddr, &addrlen), SyscallFailsWithErrno(ENOTSUP)); } +TEST_P(RawSocketTest, BindThenGetSockName) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); + + struct sockaddr* addr = reinterpret_cast<struct sockaddr*>(&addr_); + ASSERT_THAT(bind(s_, addr, AddrLen()), SyscallSucceeds()); + struct sockaddr_storage saddr_storage; + struct sockaddr* saddr = reinterpret_cast<struct sockaddr*>(&saddr_storage); + socklen_t saddrlen = AddrLen(); + ASSERT_THAT(getsockname(s_, saddr, &saddrlen), SyscallSucceeds()); + ASSERT_EQ(saddrlen, AddrLen()); + + // The port is expected to hold the protocol number. + EXPECT_EQ(Port(saddr), Protocol()); + + char addrbuf[INET6_ADDRSTRLEN], saddrbuf[INET6_ADDRSTRLEN]; + const char* addrstr = + inet_ntop(addr->sa_family, Addr(addr), addrbuf, sizeof(addrbuf)); + ASSERT_NE(addrstr, nullptr); + const char* saddrstr = + inet_ntop(saddr->sa_family, Addr(saddr), saddrbuf, sizeof(saddrbuf)); + ASSERT_NE(saddrstr, nullptr); + EXPECT_STREQ(saddrstr, addrstr); +} + +TEST_P(RawSocketTest, ConnectThenGetSockName) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); + + struct sockaddr* addr = reinterpret_cast<struct sockaddr*>(&addr_); + ASSERT_THAT(connect(s_, addr, AddrLen()), SyscallSucceeds()); + struct sockaddr_storage saddr_storage; + struct sockaddr* saddr = reinterpret_cast<struct sockaddr*>(&saddr_storage); + socklen_t saddrlen = AddrLen(); + ASSERT_THAT(getsockname(s_, saddr, &saddrlen), SyscallSucceeds()); + ASSERT_EQ(saddrlen, AddrLen()); + + // The port is expected to hold the protocol number. + EXPECT_EQ(Port(saddr), Protocol()); + + char addrbuf[INET6_ADDRSTRLEN], saddrbuf[INET6_ADDRSTRLEN]; + const char* addrstr = + inet_ntop(addr->sa_family, Addr(addr), addrbuf, sizeof(addrbuf)); + ASSERT_NE(addrstr, nullptr); + const char* saddrstr = + inet_ntop(saddr->sa_family, Addr(saddr), saddrbuf, sizeof(saddrbuf)); + ASSERT_NE(saddrstr, nullptr); + EXPECT_STREQ(saddrstr, addrstr); +} + // Test that getpeername() returns nothing before connect(). TEST_P(RawSocketTest, FailGetPeerNameBeforeConnect) { SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); |