247 lines
6.2 KiB
Go
247 lines
6.2 KiB
Go
//go:build linux
|
|
|
|
package main
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"os"
|
|
"os/signal"
|
|
"strings"
|
|
"time"
|
|
|
|
"omnisocketgo/cmd/internal/latencylog"
|
|
peerpkg "omnisocketgo/cmd/internal/peer"
|
|
"omnisocketgo/cmd/internal/protocol"
|
|
)
|
|
|
|
func runPlatform(cfg config, stdout, stderr io.Writer, now func() time.Time) error {
|
|
ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt)
|
|
defer stop()
|
|
|
|
client, closeLogger, err := dialKCPClient(cfg)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer closeLogger()
|
|
defer client.Close()
|
|
|
|
if cfg.echo {
|
|
return runEchoMode(ctx, client, stderr)
|
|
}
|
|
return runPingMode(ctx, client, cfg, stdout, stderr, now)
|
|
}
|
|
|
|
func dialKCPClient(cfg config) (*peerpkg.KCPClient, func(), error) {
|
|
options := make([]peerpkg.Option, 0, 3)
|
|
closeLogger := func() {}
|
|
|
|
if cfg.latencyLog != "" {
|
|
logger, err := latencylog.NewJSONLLogger(cfg.latencyLog)
|
|
if err != nil {
|
|
return nil, nil, fmt.Errorf("create latency logger %s: %w", cfg.latencyLog, err)
|
|
}
|
|
options = append(options, peerpkg.WithLogger(logger))
|
|
closeLogger = func() {
|
|
_ = logger.Close()
|
|
}
|
|
}
|
|
if cfg.bindIP != "" {
|
|
options = append(options, peerpkg.WithBindIP(cfg.bindIP))
|
|
}
|
|
if cfg.bindDevice != "" {
|
|
options = append(options, peerpkg.WithBindDevice(cfg.bindDevice))
|
|
}
|
|
|
|
client, err := peerpkg.DialKCP(cfg.server, cfg.id, options...)
|
|
if err != nil {
|
|
closeLogger()
|
|
return nil, nil, fmt.Errorf("dial kcp server %s: %w", cfg.server, err)
|
|
}
|
|
return client, closeLogger, nil
|
|
}
|
|
|
|
func runPingMode(ctx context.Context, client *peerpkg.KCPClient, cfg config, stdout, stderr io.Writer, now func() time.Time) error {
|
|
if err := writePingHeader(stdout, cfg); err != nil {
|
|
return err
|
|
}
|
|
|
|
receiveCh := make(chan protocol.Message, 32)
|
|
receiveErrCh := make(chan error, 1)
|
|
go func() {
|
|
for {
|
|
msg, err := client.Receive()
|
|
if err != nil {
|
|
receiveErrCh <- err
|
|
return
|
|
}
|
|
receiveCh <- msg
|
|
}
|
|
}()
|
|
|
|
tracker := newPingTracker(cfg.timeout)
|
|
expiryTicker := time.NewTicker(expiryPollInterval(cfg.timeout))
|
|
defer expiryTicker.Stop()
|
|
|
|
var sendTicker *time.Ticker
|
|
if cfg.count == 0 || cfg.count > 1 {
|
|
sendTicker = time.NewTicker(cfg.interval)
|
|
defer sendTicker.Stop()
|
|
}
|
|
|
|
nextSeq := uint64(1)
|
|
if err := sendPing(client, tracker, cfg, nextSeq, now); err != nil {
|
|
return err
|
|
}
|
|
nextSeq++
|
|
|
|
stopSending := cfg.count == 1
|
|
receiveErrSeen := false
|
|
|
|
for {
|
|
if stopSending && tracker.pendingCount() == 0 {
|
|
break
|
|
}
|
|
|
|
var sendTick <-chan time.Time
|
|
if !stopSending && sendTicker != nil {
|
|
sendTick = sendTicker.C
|
|
}
|
|
|
|
select {
|
|
case <-ctx.Done():
|
|
stopSending = true
|
|
case <-expiryTicker.C:
|
|
for _, seq := range tracker.expire(now()) {
|
|
if err := writeTimeout(stdout, seq); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
case msg := <-receiveCh:
|
|
if err := handlePingMessage(tracker, msg, stdout, stderr, now); err != nil {
|
|
return err
|
|
}
|
|
case err := <-receiveErrCh:
|
|
receiveErrSeen = true
|
|
if ctx.Err() != nil && isExpectedCloseError(err) {
|
|
break
|
|
}
|
|
if stopSending && tracker.pendingCount() == 0 && isExpectedCloseError(err) {
|
|
break
|
|
}
|
|
return fmt.Errorf("receive reply: %w", err)
|
|
case <-sendTick:
|
|
if cfg.count > 0 && tracker.sent >= cfg.count {
|
|
stopSending = true
|
|
continue
|
|
}
|
|
if err := sendPing(client, tracker, cfg, nextSeq, now); err != nil {
|
|
return err
|
|
}
|
|
nextSeq++
|
|
if cfg.count > 0 && tracker.sent >= cfg.count {
|
|
stopSending = true
|
|
}
|
|
}
|
|
|
|
if receiveErrSeen && stopSending && tracker.pendingCount() == 0 {
|
|
break
|
|
}
|
|
}
|
|
|
|
for _, seq := range tracker.expire(now()) {
|
|
if err := writeTimeout(stdout, seq); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
return writeSummary(stdout, cfg.to, tracker.summary())
|
|
}
|
|
|
|
func sendPing(client *peerpkg.KCPClient, tracker *pingTracker, cfg config, seq uint64, now func() time.Time) error {
|
|
sentAt := now()
|
|
payload, err := buildPingPayload(seq, sentAt.UnixNano(), cfg.size)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if err := client.SendText(cfg.to, string(payload)); err != nil {
|
|
return fmt.Errorf("send ping seq=%d: %w", seq, err)
|
|
}
|
|
tracker.markSent(seq, sentAt)
|
|
return nil
|
|
}
|
|
|
|
func handlePingMessage(tracker *pingTracker, msg protocol.Message, stdout, stderr io.Writer, now func() time.Time) error {
|
|
switch msg.Type {
|
|
case protocol.MessageTypeText:
|
|
payload, err := parsePingPayload(msg.Body)
|
|
if err != nil {
|
|
_, writeErr := fmt.Fprintf(stderr, "ignore non-ping text message from %s: %v\n", msg.From, err)
|
|
if writeErr != nil {
|
|
return writeErr
|
|
}
|
|
return nil
|
|
}
|
|
|
|
result := tracker.observeReply(payload, now())
|
|
switch result.disposition {
|
|
case replyMatched:
|
|
return writeMatchedReply(stdout, payload.Seq, result.rtt)
|
|
case replyDuplicate:
|
|
_, err := fmt.Fprintf(stderr, "seq=%d duplicate or late reply ignored\n", payload.Seq)
|
|
return err
|
|
case replyUnexpected:
|
|
_, err := fmt.Fprintf(stderr, "seq=%d unexpected reply ignored\n", payload.Seq)
|
|
return err
|
|
default:
|
|
return nil
|
|
}
|
|
case protocol.MessageTypeError:
|
|
_, err := fmt.Fprintf(stderr, "server error: %s\n", strings.TrimSpace(string(msg.Body)))
|
|
return err
|
|
default:
|
|
_, err := fmt.Fprintf(stderr, "unexpected message type %s from %s ignored\n", msg.Type, msg.From)
|
|
return err
|
|
}
|
|
}
|
|
|
|
func runEchoMode(ctx context.Context, client *peerpkg.KCPClient, stderr io.Writer) error {
|
|
receiveErrCh := make(chan error, 1)
|
|
go func() {
|
|
receiveErrCh <- client.ReceiveLoop(func(msg protocol.Message) error {
|
|
switch msg.Type {
|
|
case protocol.MessageTypeText:
|
|
return client.SendText(msg.From, string(msg.Body))
|
|
case protocol.MessageTypeError:
|
|
_, err := fmt.Fprintf(stderr, "server error: %s\n", strings.TrimSpace(string(msg.Body)))
|
|
return err
|
|
default:
|
|
_, err := fmt.Fprintf(stderr, "unexpected message type %s from %s ignored\n", msg.Type, msg.From)
|
|
return err
|
|
}
|
|
})
|
|
}()
|
|
|
|
select {
|
|
case <-ctx.Done():
|
|
return nil
|
|
case err := <-receiveErrCh:
|
|
if err == nil || (ctx.Err() != nil && isExpectedCloseError(err)) {
|
|
return nil
|
|
}
|
|
return fmt.Errorf("echo receive loop: %w", err)
|
|
}
|
|
}
|
|
|
|
// isExpectedCloseError reports whether err looks like the result of the
// connection being shut down rather than a genuine failure. A nil error is
// treated as expected.
//
// Standard-library sentinels are matched with errors.Is, so wrapped values are
// recognised even when their message text differs. The substring checks are a
// fallback for transports that return plain, unwrapped errors. The "closed"
// substring also covers net.ErrClosed ("use of closed network connection").
func isExpectedCloseError(err error) bool {
	if err == nil {
		return true
	}
	if errors.Is(err, context.Canceled) || errors.Is(err, io.ErrClosedPipe) {
		return true
	}
	// "closed" already matches io.ErrClosedPipe's text ("io: read/write on
	// closed pipe"), so that message needs no separate check.
	message := err.Error()
	return strings.Contains(message, "closed") ||
		strings.Contains(message, "broken pipe")
}
|