File: //proc/thread-self/root/opt/go/pkg/mod/github.com/mdlayher/
[email protected]/conn_test.go
package socket_test
import (
"bytes"
"context"
"errors"
"fmt"
"io"
"math"
"net"
"net/netip"
"os"
"runtime"
"sync"
"testing"
"time"
"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
"github.com/mdlayher/socket/internal/sockettest"
"golang.org/x/net/nettest"
"golang.org/x/sync/errgroup"
"golang.org/x/sys/unix"
)
func TestConn(t *testing.T) {
t.Parallel()
tests := []struct {
name string
pipe nettest.MakePipe
}{
// Standard library plumbing.
{
name: "basic",
pipe: makePipe(
func() (net.Listener, error) {
return sockettest.Listen(0, nil)
},
func(addr net.Addr) (net.Conn, error) {
return sockettest.Dial(context.Background(), addr, nil)
},
),
},
// Our own implementations which have context cancelation support.
{
name: "context",
pipe: makePipe(
func() (net.Listener, error) {
l, err := sockettest.Listen(0, nil)
if err != nil {
return nil, err
}
return l.Context(context.Background()), nil
},
func(addr net.Addr) (net.Conn, error) {
ctx := context.Background()
c, err := sockettest.Dial(ctx, addr, nil)
if err != nil {
return nil, err
}
return c.Context(ctx), nil
},
),
},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
nettest.TestConn(t, tt.pipe)
// Our own extensions to TestConn.
t.Run("CloseReadWrite", func(t *testing.T) { timeoutWrapper(t, tt.pipe, testCloseReadWrite) })
})
}
}
// TestDialTCPNoListener dials an IPv6 loopback port that is expected to have
// no listener and requires the error to be exactly ECONNREFUSED wrapped in an
// *os.SyscallError for "connect".
func TestDialTCPNoListener(t *testing.T) {
	t.Parallel()

	// See https://github.com/mdlayher/vsock/issues/47 and
	// https://github.com/lxc/lxd/pull/9894 for context on this test.
	//
	// Given a (hopefully) non-existent listener on localhost, expect
	// ECONNREFUSED.
	target := &net.TCPAddr{
		IP:   net.IPv6loopback,
		Port: math.MaxUint16,
	}
	_, got := sockettest.Dial(context.Background(), target, nil)

	want := os.NewSyscallError("connect", unix.ECONNREFUSED)
	if diff := cmp.Diff(want, got); diff != "" {
		t.Fatalf("unexpected connect error (-want +got):\n%s", diff)
	}
}
// TestDialTCPContextCanceledBefore dials with a context that is already
// canceled and expects an error matching context.Canceled.
func TestDialTCPContextCanceledBefore(t *testing.T) {
	t.Parallel()

	// Cancel up front so that no dialing can take place.
	ctx, cancel := context.WithCancel(context.Background())
	cancel()

	target := &net.TCPAddr{IP: net.IPv6loopback, Port: math.MaxUint16}
	_, got := sockettest.Dial(ctx, target, nil)

	if diff := cmp.Diff(context.Canceled, got, cmpopts.EquateErrors()); diff != "" {
		t.Fatalf("unexpected connect error (-want +got):\n%s", diff)
	}
}
// ipTests holds one documentation-range address per IP family (192.0.2.0/24
// from RFC 5737 and 2001:db8::/32 from RFC 3849). The dial cancelation tests
// use them as targets. It appears we can dial addresses in the documentation
// range and connect will hang, which is perfect for those test cases.
var ipTests = []struct {
	name string
	ip   netip.Addr
}{
	{name: "IPv4", ip: netip.MustParseAddr("192.0.2.1")},
	{name: "IPv6", ip: netip.MustParseAddr("2001:db8::1")},
}
// TestDialTCPContextCanceledDuring starts a dial to each ipTests address and
// cancels the context after one second, while the connect is expected to be
// blocked. It expects an error matching context.Canceled. If the host has no
// outbound connectivity for the address family, the subtest is skipped.
func TestDialTCPContextCanceledDuring(t *testing.T) {
	t.Parallel()

	for _, tc := range ipTests {
		tc := tc // capture per-iteration value for the parallel subtest
		t.Run(tc.name, func(t *testing.T) {
			t.Parallel()

			// Context is canceled during a blocking operation but without an
			// explicit deadline passed on the context.
			ctx, cancel := context.WithCancel(context.Background())
			defer cancel()
			time.AfterFunc(1*time.Second, cancel)

			target := &net.TCPAddr{IP: tc.ip.AsSlice(), Port: math.MaxUint16}
			_, err := sockettest.Dial(ctx, target, nil)

			unreachable := errors.Is(err, unix.ENETUNREACH) || errors.Is(err, unix.EHOSTUNREACH)
			if unreachable {
				t.Skipf("skipping, no outbound %s connectivity: %v", tc.name, err)
			}

			if diff := cmp.Diff(context.Canceled, err, cmpopts.EquateErrors()); diff != "" {
				t.Fatalf("unexpected connect error (-want +got):\n%s", diff)
			}
		})
	}
}
// TestDialTCPContextDeadlineExceeded dials each ipTests address with a one
// second context timeout and expects an error matching
// context.DeadlineExceeded. If the host has no outbound connectivity for the
// address family, the subtest is skipped.
func TestDialTCPContextDeadlineExceeded(t *testing.T) {
	t.Parallel()

	for _, tc := range ipTests {
		tc := tc // capture per-iteration value for the parallel subtest
		t.Run(tc.name, func(t *testing.T) {
			t.Parallel()

			// Dialing is canceled after the deadline passes.
			ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
			defer cancel()

			target := &net.TCPAddr{IP: tc.ip.AsSlice(), Port: math.MaxUint16}
			_, err := sockettest.Dial(ctx, target, nil)

			unreachable := errors.Is(err, unix.ENETUNREACH) || errors.Is(err, unix.EHOSTUNREACH)
			if unreachable {
				t.Skipf("skipping, no outbound %s connectivity: %v", tc.name, err)
			}

			if diff := cmp.Diff(context.DeadlineExceeded, err, cmpopts.EquateErrors()); diff != "" {
				t.Fatalf("unexpected connect error (-want +got):\n%s", diff)
			}
		})
	}
}
// TestListenerAcceptTCPContextCanceledBefore calls Accept through a listener
// whose context is already canceled and expects an error matching
// context.Canceled.
func TestListenerAcceptTCPContextCanceledBefore(t *testing.T) {
	t.Parallel()

	l, err := sockettest.Listen(0, nil)
	if err != nil {
		t.Fatalf("failed to listen: %v", err)
	}
	defer l.Close()

	// Context is canceled before accept can take place.
	ctx, cancel := context.WithCancel(context.Background())
	cancel()

	bound := l.Context(ctx)
	_, got := bound.Accept()

	if diff := cmp.Diff(context.Canceled, got, cmpopts.EquateErrors()); diff != "" {
		t.Fatalf("unexpected accept error (-want +got):\n%s", diff)
	}
}
// TestListenerAcceptTCPContextCanceledDuring blocks in Accept with no
// incoming connections, cancels the context after one second, and expects
// an error matching context.Canceled.
func TestListenerAcceptTCPContextCanceledDuring(t *testing.T) {
	t.Parallel()

	l, err := sockettest.Listen(0, nil)
	if err != nil {
		t.Fatalf("failed to listen: %v", err)
	}
	defer l.Close()

	// Context is canceled during a blocking operation but without an
	// explicit deadline passed on the context.
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()
	time.AfterFunc(1*time.Second, cancel)

	bound := l.Context(ctx)
	_, got := bound.Accept()

	if diff := cmp.Diff(context.Canceled, got, cmpopts.EquateErrors()); diff != "" {
		t.Fatalf("unexpected accept error (-want +got):\n%s", diff)
	}
}
// TestListenerAcceptTCPContextDeadlineExceeded blocks in Accept under a one
// second context timeout, with no incoming connections, and expects an error
// matching context.DeadlineExceeded.
func TestListenerAcceptTCPContextDeadlineExceeded(t *testing.T) {
	t.Parallel()

	l, err := sockettest.Listen(0, nil)
	if err != nil {
		t.Fatalf("failed to listen: %v", err)
	}
	defer l.Close()

	// Accept is canceled after the deadline passes.
	ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
	defer cancel()

	bound := l.Context(ctx)
	_, got := bound.Accept()

	if diff := cmp.Diff(context.DeadlineExceeded, got, cmpopts.EquateErrors()); diff != "" {
		t.Fatalf("unexpected accept error (-want +got):\n%s", diff)
	}
}
// TestListenerConnTCPContextCanceled accepts one TCP connection and calls
// Recvfrom on it with a context that is already canceled. It expects an
// error matching context.Canceled.
func TestListenerConnTCPContextCanceled(t *testing.T) {
	t.Parallel()

	l, err := sockettest.Listen(0, nil)
	if err != nil {
		t.Fatalf("failed to open listener: %v", err)
	}
	defer l.Close()

	// Accept a single connection and attempt to read from it.
	var eg errgroup.Group
	eg.Go(func() error {
		c, err := l.Accept()
		if err != nil {
			return fmt.Errorf("failed to accept: %v", err)
		}
		defer c.Close()

		// Context is canceled before recvfrom can take place.
		ctx, cancel := context.WithCancel(context.Background())
		cancel()

		b := make([]byte, 1024)
		_, _, err = c.(*sockettest.Conn).Conn.Recvfrom(ctx, b, 0)
		return err
	})

	c, err := net.Dial(l.Addr().Network(), l.Addr().String())
	if err != nil {
		t.Fatalf("failed to dial listener: %v", err)
	}
	defer c.Close()

	// The client never sends data, so recvfrom can only return because of
	// the already-canceled context.
	if diff := cmp.Diff(context.Canceled, eg.Wait(), cmpopts.EquateErrors()); diff != "" {
		t.Fatalf("unexpected recvfrom error (-want +got):\n%s", diff)
	}
}
// TestListenerConnTCPContextDeadlineExceeded accepts one TCP connection and
// blocks in Recvfrom on it under a one second context timeout. It expects an
// error matching context.DeadlineExceeded.
func TestListenerConnTCPContextDeadlineExceeded(t *testing.T) {
	t.Parallel()

	l, err := sockettest.Listen(0, nil)
	if err != nil {
		t.Fatalf("failed to open listener: %v", err)
	}
	defer l.Close()

	// Accept a single connection and attempt to read from it.
	var eg errgroup.Group
	eg.Go(func() error {
		c, err := l.Accept()
		if err != nil {
			return fmt.Errorf("failed to accept: %v", err)
		}
		defer c.Close()

		// Recvfrom is canceled after the deadline passes.
		ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
		defer cancel()

		b := make([]byte, 1024)
		_, _, err = c.(*sockettest.Conn).Conn.Recvfrom(ctx, b, 0)
		return err
	})

	c, err := net.Dial(l.Addr().Network(), l.Addr().String())
	if err != nil {
		t.Fatalf("failed to dial listener: %v", err)
	}
	defer c.Close()

	// The client never sends data, so we wait until the context deadline
	// expires and the errgroup returns.
	if diff := cmp.Diff(context.DeadlineExceeded, eg.Wait(), cmpopts.EquateErrors()); diff != "" {
		t.Fatalf("unexpected recvfrom error (-want +got):\n%s", diff)
	}
}
// TestFileConn builds an IPv6 TCP listening socket with raw system calls,
// hands it to sockettest.FileListener, and checks that a single dial/accept
// round trip against the resulting listener succeeds.
func TestFileConn(t *testing.T) {
	t.Parallel()

	// Use raw system calls to set up the socket since we assume anything being
	// passed into a FileConn is set up by another system, such as systemd's
	// socket activation.
	fd, err := unix.Socket(unix.AF_INET6, unix.SOCK_STREAM, 0)
	if err != nil {
		t.Fatalf("failed to open socket: %v", err)
	}

	// Hand fd to an *os.File right away so that it is closed even if bind or
	// listen fails below and t.Fatalf aborts the test. The socket was opened
	// without SOCK_NONBLOCK, so this is a blocking file, as FileListener expects.
	f := os.NewFile(uintptr(fd), "tcpv6-listener")
	defer f.Close()

	// Bind to loopback, any available port.
	sa := &unix.SockaddrInet6{Addr: [16]byte{15: 0x01}}
	if err := unix.Bind(fd, sa); err != nil {
		t.Fatalf("failed to bind: %v", err)
	}
	if err := unix.Listen(fd, unix.SOMAXCONN); err != nil {
		t.Fatalf("failed to listen: %v", err)
	}

	// The socket is now ready to be passed into FileConn via the
	// FileListener helper.
	l, err := sockettest.FileListener(f)
	if err != nil {
		t.Fatalf("failed to open file listener: %v", err)
	}
	defer l.Close()

	// To exercise the listener, attempt to accept and then immediately close a
	// single TCPv6 connection. Dial to the listener from the main goroutine and
	// wait for everything to finish.
	var eg errgroup.Group
	eg.Go(func() error {
		c, err := l.Accept()
		if err != nil {
			return fmt.Errorf("failed to accept: %v", err)
		}
		_ = c.Close()
		return nil
	})

	c, err := net.Dial(l.Addr().Network(), l.Addr().String())
	if err != nil {
		t.Fatalf("failed to dial listener: %v", err)
	}
	_ = c.Close()

	if err := eg.Wait(); err != nil {
		t.Fatalf("failed to wait for listener goroutine: %v", err)
	}
}
// Use our TCP net.Listener and net.Conn implementations backed by *socket.Conn
// and run compliance tests with nettest.TestConn.
//
// makePipe returns a nettest.MakePipe that opens a listener with listen,
// connects to it with dial, and returns the dialed conn as c1 and the
// accepted conn as c2. On success, stop closes both conns and the listener.
// On failure, every resource that was opened is closed before the error is
// returned.
//
// This nettest.MakePipe function is adapted from nettest's own tests:
// https://github.com/golang/net/blob/master/nettest/conntest_test.go
//
// Copyright 2016 The Go Authors. All rights reserved. Use of this source
// code is governed by a BSD-style license that can be found in the LICENSE
// file.
func makePipe(
	listen func() (net.Listener, error),
	dial func(addr net.Addr) (net.Conn, error),
) nettest.MakePipe {
	return func() (c1, c2 net.Conn, stop func(), err error) {
		ln, err := listen()
		if err != nil {
			return nil, nil, nil, err
		}

		// Start a connection between two endpoints. c2 and err2 are written
		// by the goroutine and only read after done is closed.
		var err1, err2 error
		done := make(chan struct{})
		go func() {
			c2, err2 = ln.Accept()
			close(done)
		}()

		c1, err1 = dial(ln.Addr())
		if err1 != nil {
			// No peer will ever arrive, so Accept would block forever and
			// <-done would hang. Closing the listener unblocks Accept.
			_ = ln.Close()
			<-done
			if err2 == nil {
				// Accept raced ahead of the close; release that conn too.
				_ = c2.Close()
			}
			return nil, nil, nil, err1
		}

		<-done
		if err2 != nil {
			_ = c1.Close()
			_ = ln.Close()
			return nil, nil, nil, err2
		}

		stop = func() {
			c1.Close()
			c2.Close()
			ln.Close()
		}
		return c1, c2, stop, nil
	}
}
// connTester is a test body that exercises a connected pair of net.Conns.
//
// connTester and timeoutWrapper are copied from x/net/nettest, pending
// acceptance of:
// https://go-review.googlesource.com/c/net/+/372815
type connTester func(t *testing.T, c1, c2 net.Conn)

// timeoutWrapper creates a pipe with mp and runs f against the two ends.
// If f has not returned within one minute, the test is marked as failed and
// the pipe's stop function is invoked, presumably unblocking f. The stop
// function runs at most once: either from that timeout or when f returns.
func timeoutWrapper(t *testing.T, mp nettest.MakePipe, f connTester) {
	t.Helper()

	c1, c2, stop, err := mp()
	if err != nil {
		t.Fatalf("unable to make pipe: %v", err)
	}

	// once ensures the timeout path and the normal return path never both
	// tear down the pipe.
	var once sync.Once
	defer once.Do(stop)

	onTimeout := func() {
		t.Error("test timed out; terminating pipe")
		stop()
	}
	timer := time.AfterFunc(time.Minute, func() { once.Do(onTimeout) })
	defer timer.Stop()

	f(t, c1, c2)
}
// testCloseReadWrite tests that net.Conns which also implement the optional
// CloseRead and CloseWrite methods can be half-closed correctly:
//   - after cc1.CloseWrite, a further Write on c1 must fail with a
//     non-timeout net.Error;
//   - after cc2.CloseRead, a further Read on c2 must return io.EOF.
//
// The test is skipped on non-Linux platforms, and when either conn does not
// implement both CloseRead and CloseWrite.
func testCloseReadWrite(t *testing.T, c1, c2 net.Conn) {
	// TODO(mdlayher): investigate why Mac/Windows errors are so different.
	if runtime.GOOS != "linux" {
		t.Skip("skipping, not supported on non-Linux platforms")
	}

	type closerConn interface {
		net.Conn
		CloseRead() error
		CloseWrite() error
	}

	cc1, ok1 := c1.(closerConn)
	cc2, ok2 := c2.(closerConn)
	if !ok1 || !ok2 {
		// Both c1 and c2 must implement closerConn to proceed. Report a skip
		// rather than a vacuous pass.
		t.Skip("skipping, net.Conns do not implement CloseRead and CloseWrite")
	}

	var wg sync.WaitGroup
	wg.Add(2)
	defer wg.Wait()

	go func() {
		defer wg.Done()

		// Writing succeeds at first but should result in a permanent "broken
		// pipe" error after closing the write side of the net.Conn.
		b := make([]byte, 64)
		if err := chunkedCopy(cc1, bytes.NewReader(b)); err != nil {
			t.Errorf("unexpected initial cc1.Write error: %v", err)
		}
		if err := cc1.CloseWrite(); err != nil {
			t.Errorf("unexpected cc1.CloseWrite error: %v", err)
		}

		// errors.As also accepts implementations that wrap the net.Error.
		_, err := cc1.Write(b)
		var nerr net.Error
		if !errors.As(err, &nerr) || nerr.Timeout() {
			t.Errorf("unexpected final cc1.Write error: %v", err)
		}
	}()

	go func() {
		defer wg.Done()

		// Reading succeeds at first but should result in an EOF error after
		// closing the read side of the net.Conn.
		if err := chunkedCopy(io.Discard, cc2); err != nil {
			t.Errorf("unexpected initial cc2.Read error: %v", err)
		}
		if err := cc2.CloseRead(); err != nil {
			t.Errorf("unexpected cc2.CloseRead error: %v", err)
		}
		if _, err := cc2.Read(make([]byte, 64)); !errors.Is(err, io.EOF) {
			t.Errorf("unexpected final cc2.Read error: %v", err)
		}
	}()
}
// chunkedCopy streams everything from r into w through a 1024-byte buffer,
// so that no single Write is larger than 1024 bytes. This keeps writes under
// the maximum packet size of packet-based connections like "unixpacket".
// We assume that the maximum packet size is at least 1024.
//
// It returns the first non-EOF error from reading or writing, or nil when r
// is drained.
func chunkedCopy(w io.Writer, r io.Reader) error {
	// Embedding only the bare Reader/Writer hides any WriterTo/ReaderFrom
	// fast paths, so io.CopyBuffer really does go through buf.
	src := struct{ io.Reader }{Reader: r}
	dst := struct{ io.Writer }{Writer: w}

	buf := make([]byte, 1024)
	if _, err := io.CopyBuffer(dst, src, buf); err != nil {
		return err
	}
	return nil
}