diff options
Diffstat (limited to 'pkg/tcpip/transport')
-rw-r--r-- | pkg/tcpip/transport/tcp/endpoint_state.go | 1 | ||||
-rw-r--r-- | pkg/tcpip/transport/udp/endpoint.go | 17 | ||||
-rw-r--r-- | pkg/tcpip/transport/udp/endpoint_state.go | 2 | ||||
-rw-r--r-- | pkg/tcpip/transport/udp/udp_test.go | 58 |
4 files changed, 55 insertions, 23 deletions
diff --git a/pkg/tcpip/transport/tcp/endpoint_state.go b/pkg/tcpip/transport/tcp/endpoint_state.go index 212d2513a..b1e249bff 100644 --- a/pkg/tcpip/transport/tcp/endpoint_state.go +++ b/pkg/tcpip/transport/tcp/endpoint_state.go @@ -213,6 +213,7 @@ func loadError(s string) *tcpip.Error { tcpip.ErrInvalidOptionValue, tcpip.ErrNoLinkAddress, tcpip.ErrBadAddress, + tcpip.ErrNetworkUnreachable, } messageToError = make(map[string]*tcpip.Error) diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index f86fc6d5a..6fcddd028 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -62,7 +62,6 @@ type endpoint struct { id stack.TransportEndpointID state endpointState bindNICID tcpip.NICID - bindAddr tcpip.Address regNICID tcpip.NICID route stack.Route `state:"manual"` dstPort uint16 @@ -267,13 +266,13 @@ func (e *endpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (uintptr, *tc toCopy := *to to = &toCopy - netProto, err := e.checkV4Mapped(to, true) + netProto, err := e.checkV4Mapped(to, false) if err != nil { return 0, err } // Find the enpoint. - r, err := e.stack.FindRoute(nicid, e.bindAddr, to.Addr, netProto) + r, err := e.stack.FindRoute(nicid, e.id.LocalAddress, to.Addr, netProto) if err != nil { return 0, err } @@ -439,11 +438,16 @@ func (e *endpoint) checkV4Mapped(addr *tcpip.FullAddress, allowMismatch bool) (t if addr.Addr == "\x00\x00\x00\x00" { addr.Addr = "" } + + // Fail if we are bound to an IPv6 address. + if !allowMismatch && len(e.id.LocalAddress) == 16 { + return 0, tcpip.ErrNetworkUnreachable + } } // Fail if we're bound to an address length different from the one we're // checking. - if l := len(e.id.LocalAddress); !allowMismatch && l != 0 && l != len(addr.Addr) { + if l := len(e.id.LocalAddress); l != 0 && l != len(addr.Addr) { return 0, tcpip.ErrInvalidEndpointState } @@ -485,7 +489,7 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { } // Find a route to the desired destination. - r, err := e.stack.FindRoute(nicid, e.bindAddr, addr.Addr, netProto) + r, err := e.stack.FindRoute(nicid, e.id.LocalAddress, addr.Addr, netProto) if err != nil { return err } @@ -605,7 +609,7 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress, commit func() *tcpip.Error return tcpip.ErrInvalidEndpointState } - netProto, err := e.checkV4Mapped(&addr, false) + netProto, err := e.checkV4Mapped(&addr, true) if err != nil { return err } @@ -670,7 +674,6 @@ func (e *endpoint) Bind(addr tcpip.FullAddress, commit func() *tcpip.Error) *tcp } e.bindNICID = addr.NIC - e.bindAddr = addr.Addr return nil } diff --git a/pkg/tcpip/transport/udp/endpoint_state.go b/pkg/tcpip/transport/udp/endpoint_state.go index e20d59ca3..93784fb05 100644 --- a/pkg/tcpip/transport/udp/endpoint_state.go +++ b/pkg/tcpip/transport/udp/endpoint_state.go @@ -72,7 +72,7 @@ func (e *endpoint) afterLoad() { var err *tcpip.Error if e.state == stateConnected { - e.route, err = e.stack.FindRoute(e.regNICID, e.bindAddr, e.id.RemoteAddress, netProto) + e.route, err = e.stack.FindRoute(e.regNICID, e.id.LocalAddress, e.id.RemoteAddress, netProto) if err != nil { panic(*err) } diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go index 1eb9ecb80..cc342c69b 100644 --- a/pkg/tcpip/transport/udp/udp_test.go +++ b/pkg/tcpip/transport/udp/udp_test.go @@ -401,7 +401,7 @@ func TestV4ReadOnV4(t *testing.T) { testV4Read(c) } -func testDualWrite(c *testContext) uint16 { +func testV4Write(c *testContext) uint16 { // Write to V4 mapped address. payload := buffer.View(newPayload()) n, err := c.ep.Write(tcpip.SlicePayload(payload), tcpip.WriteOptions{ @@ -423,16 +423,18 @@ func testDualWrite(c *testContext) uint16 { ), ) - port := udp.SourcePort() - // Check the payload. if !bytes.Equal(payload, udp.Payload()) { c.t.Fatalf("Bad payload: got %x, want %x", udp.Payload(), payload) } + return udp.SourcePort() +} + +func testV6Write(c *testContext) uint16 { // Write to v6 address. - payload = buffer.View(newPayload()) - n, err = c.ep.Write(tcpip.SlicePayload(payload), tcpip.WriteOptions{ + payload := buffer.View(newPayload()) + n, err := c.ep.Write(tcpip.SlicePayload(payload), tcpip.WriteOptions{ To: &tcpip.FullAddress{Addr: testV6Addr, Port: testPort}, }) if err != nil { @@ -442,14 +444,12 @@ func testDualWrite(c *testContext) uint16 { c.t.Fatalf("Bad number of bytes written: got %v, want %v", n, len(payload)) } - // Check that we received the packet, and that the source port is the - // same as the one used in ipv4. - b = c.getV6Packet() - udp = header.UDP(header.IPv6(b).Payload()) + // Check that we received the packet. + b := c.getV6Packet() + udp := header.UDP(header.IPv6(b).Payload()) checker.IPv6(c.t, b, checker.UDP( checker.DstPort(testPort), - checker.SrcPort(port), ), ) @@ -458,7 +458,17 @@ func testDualWrite(c *testContext) uint16 { c.t.Fatalf("Bad payload: got %x, want %x", udp.Payload(), payload) } - return port + return udp.SourcePort() +} + +func testDualWrite(c *testContext) uint16 { + v4Port := testV4Write(c) + v6Port := testV6Write(c) + if v4Port != v6Port { + c.t.Fatalf("expected v4 and v6 ports to be equal: got v4Port = %d, v6Port = %d", v4Port, v6Port) + } + + return v4Port } func TestDualWriteUnbound(t *testing.T) { @@ -498,7 +508,16 @@ func TestDualWriteConnectedToV6(t *testing.T) { c.t.Fatalf("Bind failed: %v", err) } - testDualWrite(c) + testV6Write(c) + + // Write to V4 mapped address. + payload := buffer.View(newPayload()) + _, err := c.ep.Write(tcpip.SlicePayload(payload), tcpip.WriteOptions{ + To: &tcpip.FullAddress{Addr: testV4MappedAddr, Port: testPort}, + }) + if err != tcpip.ErrNetworkUnreachable { + c.t.Fatalf("Write returned unexpected error: got %v, want %v", err, tcpip.ErrNetworkUnreachable) + } } func TestDualWriteConnectedToV4Mapped(t *testing.T) { @@ -512,7 +531,16 @@ func TestDualWriteConnectedToV4Mapped(t *testing.T) { c.t.Fatalf("Bind failed: %v", err) } - testDualWrite(c) + testV4Write(c) + + // Write to v6 address. + payload := buffer.View(newPayload()) + _, err := c.ep.Write(tcpip.SlicePayload(payload), tcpip.WriteOptions{ + To: &tcpip.FullAddress{Addr: testV6Addr, Port: testPort}, + }) + if err != tcpip.ErrInvalidEndpointState { + c.t.Fatalf("Write returned unexpected error: got %v, want %v", err, tcpip.ErrInvalidEndpointState) + } } func TestV4WriteOnV6Only(t *testing.T) { @@ -547,8 +575,8 @@ func TestV6WriteOnBoundToV4Mapped(t *testing.T) { _, err := c.ep.Write(tcpip.SlicePayload(payload), tcpip.WriteOptions{ To: &tcpip.FullAddress{Addr: testV6Addr, Port: testPort}, }) - if err != tcpip.ErrNoRoute { - c.t.Fatalf("Write returned unexpected error: got %v, want %v", err, tcpip.ErrNoRoute) + if err != tcpip.ErrInvalidEndpointState { + c.t.Fatalf("Write returned unexpected error: got %v, want %v", err, tcpip.ErrInvalidEndpointState) } } |