diff options
-rw-r--r-- | pkg/tcpip/stack/BUILD | 1 | ||||
-rw-r--r-- | pkg/tcpip/stack/nic.go | 2 | ||||
-rw-r--r-- | pkg/tcpip/stack/stack_test.go | 39 | ||||
-rw-r--r-- | pkg/tcpip/stack/transport_test.go | 112 |
4 files changed, 148 insertions, 6 deletions
diff --git a/pkg/tcpip/stack/BUILD b/pkg/tcpip/stack/BUILD index 9ff1c8731..8a598c57d 100644 --- a/pkg/tcpip/stack/BUILD +++ b/pkg/tcpip/stack/BUILD @@ -43,6 +43,7 @@ go_test( "//pkg/tcpip/buffer", "//pkg/tcpip/header", "//pkg/tcpip/link/channel", + "//pkg/tcpip/link/loopback", "//pkg/waiter", ], ) diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index 770d288cf..586ca873e 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -428,6 +428,8 @@ func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, _ tcpip.LinkAddr ref, ok := n.endpoints[NetworkEndpointID{dst}] n.mu.RUnlock() if ok && ref.tryIncRef() { + r.RemoteAddress = src + // TODO: Update the source NIC as well. ref.ep.HandlePacket(&r, vv) ref.decRef() } else { diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go index 391319f35..163fadded 100644 --- a/pkg/tcpip/stack/stack_test.go +++ b/pkg/tcpip/stack/stack_test.go @@ -1039,6 +1039,45 @@ func TestGetMainNICAddressAddRemove(t *testing.T) { } } +func TestNICForwarding(t *testing.T) { + // Create a stack with the fake network protocol, two NICs, each with + // an address. + s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) + s.SetForwarding(true) + + id1, linkEP1 := channel.New(10, defaultMTU, "") + if err := s.CreateNIC(1, id1); err != nil { + t.Fatalf("CreateNIC #1 failed: %v", err) + } + if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil { + t.Fatalf("AddAddress #1 failed: %v", err) + } + + id2, linkEP2 := channel.New(10, defaultMTU, "") + if err := s.CreateNIC(2, id2); err != nil { + t.Fatalf("CreateNIC #2 failed: %v", err) + } + if err := s.AddAddress(2, fakeNetNumber, "\x02"); err != nil { + t.Fatalf("AddAddress #2 failed: %v", err) + } + + // Route all packets to address 3 to NIC 2. + s.SetRouteTable([]tcpip.Route{ + {"\x03", "\xff", "\x00", 2}, + }) + + // Send a packet to address 3. + buf := buffer.NewView(30) + buf[0] = 3 + linkEP1.Inject(fakeNetNumber, buf.ToVectorisedView()) + + select { + case <-linkEP2.C: + default: + t.Fatal("Packet not forwarded") + } +} + func init() { stack.RegisterNetworkProtocolFactory("fakeNet", func() stack.NetworkProtocol { return &fakeNetworkProtocol{} diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go index 022207081..da460db77 100644 --- a/pkg/tcpip/stack/transport_test.go +++ b/pkg/tcpip/stack/transport_test.go @@ -20,6 +20,7 @@ import ( "gvisor.googlesource.com/gvisor/pkg/tcpip" "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer" "gvisor.googlesource.com/gvisor/pkg/tcpip/link/channel" + "gvisor.googlesource.com/gvisor/pkg/tcpip/link/loopback" "gvisor.googlesource.com/gvisor/pkg/tcpip/stack" "gvisor.googlesource.com/gvisor/pkg/waiter" ) @@ -42,6 +43,9 @@ type fakeTransportEndpoint struct { proto *fakeTransportProtocol peerAddr tcpip.Address route stack.Route + + // acceptQueue is non-nil iff bound. + acceptQueue []fakeTransportEndpoint } func newFakeTransportEndpoint(stack *stack.Stack, proto *fakeTransportProtocol, netProto tcpip.NetworkProtocolNumber) tcpip.Endpoint { @@ -132,11 +136,27 @@ func (*fakeTransportEndpoint) Listen(int) *tcpip.Error { return nil } -func (*fakeTransportEndpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) { - return nil, nil, nil -} - -func (*fakeTransportEndpoint) Bind(_ tcpip.FullAddress, commit func() *tcpip.Error) *tcpip.Error { +func (f *fakeTransportEndpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) { + if len(f.acceptQueue) == 0 { + return nil, nil, nil + } + a := f.acceptQueue[0] + f.acceptQueue = f.acceptQueue[1:] + return &a, nil, nil +} + +func (f *fakeTransportEndpoint) Bind(a tcpip.FullAddress, commit func() *tcpip.Error) *tcpip.Error { + if err := f.stack.RegisterTransportEndpoint( + a.NIC, + []tcpip.NetworkProtocolNumber{fakeNetNumber}, + fakeTransNumber, + stack.TransportEndpointID{LocalAddress: a.Addr}, + f, + false, + ); err != nil { + return err + } + f.acceptQueue = []fakeTransportEndpoint{} return commit() } @@ -148,9 +168,19 @@ func (*fakeTransportEndpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Erro return tcpip.FullAddress{}, nil } -func (f *fakeTransportEndpoint) HandlePacket(*stack.Route, stack.TransportEndpointID, buffer.VectorisedView) { +func (f *fakeTransportEndpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, _ buffer.VectorisedView) { // Increment the number of received packets. f.proto.packetCount++ + if f.acceptQueue != nil { + f.acceptQueue = append(f.acceptQueue, fakeTransportEndpoint{ + id: id, + stack: f.stack, + netProto: f.netProto, + proto: f.proto, + peerAddr: r.RemoteAddress, + route: r.Clone(), + }) + } } func (f *fakeTransportEndpoint) HandleControlPacket(stack.TransportEndpointID, stack.ControlType, uint32, buffer.VectorisedView) { @@ -415,6 +445,76 @@ func TestTransportOptions(t *testing.T) { } } +func TestTransportForwarding(t *testing.T) { + s := stack.New([]string{"fakeNet"}, []string{"fakeTrans"}, stack.Options{}) + s.SetForwarding(true) + + // TODO: Change this to a channel NIC. + id1 := loopback.New() + if err := s.CreateNIC(1, id1); err != nil { + t.Fatalf("CreateNIC #1 failed: %v", err) + } + if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil { + t.Fatalf("AddAddress #1 failed: %v", err) + } + + id2, linkEP2 := channel.New(10, defaultMTU, "") + if err := s.CreateNIC(2, id2); err != nil { + t.Fatalf("CreateNIC #2 failed: %v", err) + } + if err := s.AddAddress(2, fakeNetNumber, "\x02"); err != nil { + t.Fatalf("AddAddress #2 failed: %v", err) + } + + // Route all packets to address 3 to NIC 2 and all packets to address + // 1 to NIC 1. + s.SetRouteTable([]tcpip.Route{ + {"\x03", "\xff", "\x00", 2}, + {"\x01", "\xff", "\x00", 1}, + }) + + wq := waiter.Queue{} + ep, err := s.NewEndpoint(fakeTransNumber, fakeNetNumber, &wq) + if err != nil { + t.Fatalf("NewEndpoint failed: %v", err) + } + + if err := ep.Bind(tcpip.FullAddress{Addr: "\x01", NIC: 1}, func() *tcpip.Error { return nil }); err != nil { + t.Fatalf("Bind failed: %v", err) + } + + // Send a packet to address 1 from address 3. + req := buffer.NewView(30) + req[0] = 1 + req[1] = 3 + req[2] = byte(fakeTransNumber) + linkEP2.Inject(fakeNetNumber, req.ToVectorisedView()) + + aep, _, err := ep.Accept() + if err != nil || aep == nil { + t.Fatalf("Accept failed: %v, %v", aep, err) + } + + resp := buffer.NewView(30) + if _, _, err := aep.Write(tcpip.SlicePayload(resp), tcpip.WriteOptions{}); err != nil { + t.Fatalf("Write failed: %v", err) + } + + var p channel.PacketInfo + select { + case p = <-linkEP2.C: + default: + t.Fatal("Response packet not forwarded") + } + + if dst := p.Header[0]; dst != 3 { + t.Errorf("Response packet has incorrect destination addresss: got = %d, want = 3", dst) + } + if src := p.Header[1]; src != 1 { + t.Errorf("Response packet has incorrect source addresss: got = %d, want = 3", src) + } +} + func init() { stack.RegisterTransportProtocolFactory("fakeTrans", func() stack.TransportProtocol { return &fakeTransportProtocol{} |