summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--pkg/tcpip/stack/BUILD1
-rw-r--r--pkg/tcpip/stack/nic.go2
-rw-r--r--pkg/tcpip/stack/stack_test.go39
-rw-r--r--pkg/tcpip/stack/transport_test.go112
4 files changed, 148 insertions, 6 deletions
diff --git a/pkg/tcpip/stack/BUILD b/pkg/tcpip/stack/BUILD
index 9ff1c8731..8a598c57d 100644
--- a/pkg/tcpip/stack/BUILD
+++ b/pkg/tcpip/stack/BUILD
@@ -43,6 +43,7 @@ go_test(
"//pkg/tcpip/buffer",
"//pkg/tcpip/header",
"//pkg/tcpip/link/channel",
+ "//pkg/tcpip/link/loopback",
"//pkg/waiter",
],
)
diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go
index 770d288cf..586ca873e 100644
--- a/pkg/tcpip/stack/nic.go
+++ b/pkg/tcpip/stack/nic.go
@@ -428,6 +428,8 @@ func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, _ tcpip.LinkAddr
ref, ok := n.endpoints[NetworkEndpointID{dst}]
n.mu.RUnlock()
if ok && ref.tryIncRef() {
+ r.RemoteAddress = src
+ // TODO: Update the source NIC as well.
ref.ep.HandlePacket(&r, vv)
ref.decRef()
} else {
diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go
index 391319f35..163fadded 100644
--- a/pkg/tcpip/stack/stack_test.go
+++ b/pkg/tcpip/stack/stack_test.go
@@ -1039,6 +1039,45 @@ func TestGetMainNICAddressAddRemove(t *testing.T) {
}
}
+func TestNICForwarding(t *testing.T) {
+ // Create a stack with the fake network protocol, two NICs, each with
+ // an address.
+ s := stack.New([]string{"fakeNet"}, nil, stack.Options{})
+ s.SetForwarding(true)
+
+ id1, linkEP1 := channel.New(10, defaultMTU, "")
+ if err := s.CreateNIC(1, id1); err != nil {
+ t.Fatalf("CreateNIC #1 failed: %v", err)
+ }
+ if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil {
+ t.Fatalf("AddAddress #1 failed: %v", err)
+ }
+
+ id2, linkEP2 := channel.New(10, defaultMTU, "")
+ if err := s.CreateNIC(2, id2); err != nil {
+ t.Fatalf("CreateNIC #2 failed: %v", err)
+ }
+ if err := s.AddAddress(2, fakeNetNumber, "\x02"); err != nil {
+ t.Fatalf("AddAddress #2 failed: %v", err)
+ }
+
+ // Route all packets to address 3 to NIC 2.
+ s.SetRouteTable([]tcpip.Route{
+ {"\x03", "\xff", "\x00", 2},
+ })
+
+ // Send a packet to address 3.
+ buf := buffer.NewView(30)
+ buf[0] = 3
+ linkEP1.Inject(fakeNetNumber, buf.ToVectorisedView())
+
+ select {
+ case <-linkEP2.C:
+ default:
+ t.Fatal("Packet not forwarded")
+ }
+}
+
func init() {
stack.RegisterNetworkProtocolFactory("fakeNet", func() stack.NetworkProtocol {
return &fakeNetworkProtocol{}
diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go
index 022207081..da460db77 100644
--- a/pkg/tcpip/stack/transport_test.go
+++ b/pkg/tcpip/stack/transport_test.go
@@ -20,6 +20,7 @@ import (
"gvisor.googlesource.com/gvisor/pkg/tcpip"
"gvisor.googlesource.com/gvisor/pkg/tcpip/buffer"
"gvisor.googlesource.com/gvisor/pkg/tcpip/link/channel"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/link/loopback"
"gvisor.googlesource.com/gvisor/pkg/tcpip/stack"
"gvisor.googlesource.com/gvisor/pkg/waiter"
)
@@ -42,6 +43,9 @@ type fakeTransportEndpoint struct {
proto *fakeTransportProtocol
peerAddr tcpip.Address
route stack.Route
+
+ // acceptQueue is non-nil iff bound.
+ acceptQueue []fakeTransportEndpoint
}
func newFakeTransportEndpoint(stack *stack.Stack, proto *fakeTransportProtocol, netProto tcpip.NetworkProtocolNumber) tcpip.Endpoint {
@@ -132,11 +136,27 @@ func (*fakeTransportEndpoint) Listen(int) *tcpip.Error {
return nil
}
-func (*fakeTransportEndpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) {
- return nil, nil, nil
-}
-
-func (*fakeTransportEndpoint) Bind(_ tcpip.FullAddress, commit func() *tcpip.Error) *tcpip.Error {
+func (f *fakeTransportEndpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) {
+ if len(f.acceptQueue) == 0 {
+ return nil, nil, nil
+ }
+ a := f.acceptQueue[0]
+ f.acceptQueue = f.acceptQueue[1:]
+ return &a, nil, nil
+}
+
+func (f *fakeTransportEndpoint) Bind(a tcpip.FullAddress, commit func() *tcpip.Error) *tcpip.Error {
+ if err := f.stack.RegisterTransportEndpoint(
+ a.NIC,
+ []tcpip.NetworkProtocolNumber{fakeNetNumber},
+ fakeTransNumber,
+ stack.TransportEndpointID{LocalAddress: a.Addr},
+ f,
+ false,
+ ); err != nil {
+ return err
+ }
+ f.acceptQueue = []fakeTransportEndpoint{}
return commit()
}
@@ -148,9 +168,19 @@ func (*fakeTransportEndpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Erro
return tcpip.FullAddress{}, nil
}
-func (f *fakeTransportEndpoint) HandlePacket(*stack.Route, stack.TransportEndpointID, buffer.VectorisedView) {
+func (f *fakeTransportEndpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, _ buffer.VectorisedView) {
// Increment the number of received packets.
f.proto.packetCount++
+ if f.acceptQueue != nil {
+ f.acceptQueue = append(f.acceptQueue, fakeTransportEndpoint{
+ id: id,
+ stack: f.stack,
+ netProto: f.netProto,
+ proto: f.proto,
+ peerAddr: r.RemoteAddress,
+ route: r.Clone(),
+ })
+ }
}
func (f *fakeTransportEndpoint) HandleControlPacket(stack.TransportEndpointID, stack.ControlType, uint32, buffer.VectorisedView) {
@@ -415,6 +445,76 @@ func TestTransportOptions(t *testing.T) {
}
}
+func TestTransportForwarding(t *testing.T) {
+ s := stack.New([]string{"fakeNet"}, []string{"fakeTrans"}, stack.Options{})
+ s.SetForwarding(true)
+
+ // TODO: Change this to a channel NIC.
+ id1 := loopback.New()
+ if err := s.CreateNIC(1, id1); err != nil {
+ t.Fatalf("CreateNIC #1 failed: %v", err)
+ }
+ if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil {
+ t.Fatalf("AddAddress #1 failed: %v", err)
+ }
+
+ id2, linkEP2 := channel.New(10, defaultMTU, "")
+ if err := s.CreateNIC(2, id2); err != nil {
+ t.Fatalf("CreateNIC #2 failed: %v", err)
+ }
+ if err := s.AddAddress(2, fakeNetNumber, "\x02"); err != nil {
+ t.Fatalf("AddAddress #2 failed: %v", err)
+ }
+
+ // Route all packets to address 3 to NIC 2 and all packets to address
+ // 1 to NIC 1.
+ s.SetRouteTable([]tcpip.Route{
+ {"\x03", "\xff", "\x00", 2},
+ {"\x01", "\xff", "\x00", 1},
+ })
+
+ wq := waiter.Queue{}
+ ep, err := s.NewEndpoint(fakeTransNumber, fakeNetNumber, &wq)
+ if err != nil {
+ t.Fatalf("NewEndpoint failed: %v", err)
+ }
+
+ if err := ep.Bind(tcpip.FullAddress{Addr: "\x01", NIC: 1}, func() *tcpip.Error { return nil }); err != nil {
+ t.Fatalf("Bind failed: %v", err)
+ }
+
+ // Send a packet to address 1 from address 3.
+ req := buffer.NewView(30)
+ req[0] = 1
+ req[1] = 3
+ req[2] = byte(fakeTransNumber)
+ linkEP2.Inject(fakeNetNumber, req.ToVectorisedView())
+
+ aep, _, err := ep.Accept()
+ if err != nil || aep == nil {
+ t.Fatalf("Accept failed: %v, %v", aep, err)
+ }
+
+ resp := buffer.NewView(30)
+ if _, _, err := aep.Write(tcpip.SlicePayload(resp), tcpip.WriteOptions{}); err != nil {
+ t.Fatalf("Write failed: %v", err)
+ }
+
+ var p channel.PacketInfo
+ select {
+ case p = <-linkEP2.C:
+ default:
+ t.Fatal("Response packet not forwarded")
+ }
+
+ if dst := p.Header[0]; dst != 3 {
+ t.Errorf("Response packet has incorrect destination addresss: got = %d, want = 3", dst)
+ }
+ if src := p.Header[1]; src != 1 {
+ t.Errorf("Response packet has incorrect source addresss: got = %d, want = 3", src)
+ }
+}
+
func init() {
stack.RegisterTransportProtocolFactory("fakeTrans", func() stack.TransportProtocol {
return &fakeTransportProtocol{}