From 5b231141a75a2b3ee455a98f32eed4d7a75796b9 Mon Sep 17 00:00:00 2001 From: Mock Date: Fri, 27 Mar 2026 01:29:21 +0800 Subject: [PATCH] feat: udpping --- cmd/udpping/main.go | 389 ++++++++++++++++++++++++++++++++++ cmd/udpping/main_test.go | 203 ++++++++++++++++++ cmd/udpping/platform_linux.go | 243 +++++++++++++++++++++ cmd/udpping/platform_other.go | 14 ++ 4 files changed, 849 insertions(+) create mode 100644 cmd/udpping/main.go create mode 100644 cmd/udpping/main_test.go create mode 100644 cmd/udpping/platform_linux.go create mode 100644 cmd/udpping/platform_other.go diff --git a/cmd/udpping/main.go b/cmd/udpping/main.go new file mode 100644 index 0000000..2b416fc --- /dev/null +++ b/cmd/udpping/main.go @@ -0,0 +1,389 @@ +package main + +import ( + "encoding/json" + "errors" + "flag" + "fmt" + "io" + "math" + "os" + "sort" + "strings" + "time" +) + +const ( + defaultPeerID = "pinger" + defaultServer = "127.0.0.1:9001" + defaultCount = 100 + defaultInterval = 100 * time.Millisecond + defaultSize = 64 + defaultTimeout = 3 * time.Second + minExpiryPoll = 10 * time.Millisecond + maxExpiryPoll = 100 * time.Millisecond +) + +type config struct { + id string + server string + to string + echo bool + count int + interval time.Duration + size int + timeout time.Duration + bindIP string + latencyLog string +} + +type pingPayload struct { + Seq uint64 `json:"seq"` + TSUnixNano int64 `json:"ts_ns"` + Pad string `json:"pad"` +} + +type pendingPing struct { + deadline time.Time +} + +type replyDisposition int + +const ( + replyMatched replyDisposition = iota + replyDuplicate + replyUnexpected +) + +type replyResult struct { + disposition replyDisposition + rtt time.Duration +} + +type pingTracker struct { + timeout time.Duration + sent int + duplicates int + pending map[uint64]pendingPing + seen map[uint64]struct{} + samples []time.Duration +} + +type rttSummary struct { + Sent int + Received int + Duplicates int + LossPct float64 + Min time.Duration + Avg time.Duration + Max time.Duration + P50 time.Duration + P95 time.Duration + P99 time.Duration + StdDev time.Duration + HasSamples bool +} + +func main() { + if err := runMain(os.Args[1:], os.Stdout, os.Stderr, time.Now); err != nil { + if errors.Is(err, flag.ErrHelp) { + return + } + _, _ = fmt.Fprintln(os.Stderr, err) + os.Exit(1) + } +} + +func runMain(args []string, stdout, stderr io.Writer, now func() time.Time) error { + cfg, err := parseConfig(args, stderr) + if err != nil { + return err + } + + return runPlatform(cfg, stdout, stderr, now) +} + +func parseConfig(args []string, stderr io.Writer) (config, error) { + cfg := config{} + + flags := flag.NewFlagSet("udpping", flag.ContinueOnError) + flags.SetOutput(stderr) + flags.StringVar(&cfg.id, "id", defaultPeerID, "local peer identity") + flags.StringVar(&cfg.server, "server", defaultServer, "UDP server address") + flags.StringVar(&cfg.to, "to", "", "target peer identity in ping mode") + flags.BoolVar(&cfg.echo, "echo", false, "echo back every received text message") + flags.IntVar(&cfg.count, "count", defaultCount, "number of pings to send; 0 means run until interrupted") + flags.DurationVar(&cfg.interval, "interval", defaultInterval, "delay between ping sends") + flags.IntVar(&cfg.size, "size", defaultSize, "application payload size in bytes") + flags.DurationVar(&cfg.timeout, "timeout", defaultTimeout, "per-ping timeout") + flags.StringVar(&cfg.bindIP, "bind-ip", "", "optional local source IP used when dialing the server") + flags.StringVar(&cfg.latencyLog, "latency-log", "", "optional JSONL file path for latency timestamp logs") + + if err := flags.Parse(args); err != nil { + return config{}, err + } + if flags.NArg() > 0 { + return config{}, fmt.Errorf("unexpected positional arguments: %s", strings.Join(flags.Args(), " ")) + } + + cfg.id = strings.TrimSpace(cfg.id) + cfg.server = strings.TrimSpace(cfg.server) + cfg.to = strings.TrimSpace(cfg.to) + cfg.bindIP = strings.TrimSpace(cfg.bindIP) + cfg.latencyLog = strings.TrimSpace(cfg.latencyLog) + + if err := cfg.validate(); err != nil { + return config{}, err + } + return cfg, nil +} + +func (c config) validate() error { + if c.id == "" { + return fmt.Errorf("flag -id is required") + } + if c.server == "" { + return fmt.Errorf("flag -server is required") + } + if !c.echo && c.to == "" { + return fmt.Errorf("flag -to is required unless -echo is set") + } + if c.count < 0 { + return fmt.Errorf("flag -count must be greater than or equal to zero") + } + if c.interval <= 0 { + return fmt.Errorf("flag -interval must be greater than zero") + } + if c.size <= 0 { + return fmt.Errorf("flag -size must be greater than zero") + } + if c.timeout <= 0 { + return fmt.Errorf("flag -timeout must be greater than zero") + } + return nil +} + +func buildPingPayload(seq uint64, tsUnixNano int64, size int) ([]byte, error) { + payload := pingPayload{ + Seq: seq, + TSUnixNano: tsUnixNano, + Pad: "", + } + + body, err := json.Marshal(payload) + if err != nil { + return nil, fmt.Errorf("encode ping payload: %w", err) + } + if len(body) > size { + return nil, fmt.Errorf("requested payload size %d is too small; minimum is %d", size, len(body)) + } + + payload.Pad = strings.Repeat("A", size-len(body)) + body, err = json.Marshal(payload) + if err != nil { + return nil, fmt.Errorf("encode padded ping payload: %w", err) + } + if len(body) != size { + return nil, fmt.Errorf("encode padded ping payload: got %d bytes, want %d", len(body), size) + } + return body, nil +} + +func parsePingPayload(body []byte) (pingPayload, error) { + var payload pingPayload + if err := json.Unmarshal(body, &payload); err != nil { + return pingPayload{}, fmt.Errorf("decode ping payload: %w", err) + } + if payload.Seq == 0 { + return pingPayload{}, fmt.Errorf("decode ping payload: seq must be greater than zero") + } + if payload.TSUnixNano <= 0 { + return pingPayload{}, fmt.Errorf("decode ping payload: ts_ns must be greater than zero") + } + return payload, nil +} + +func newPingTracker(timeout time.Duration) *pingTracker { + return &pingTracker{ + timeout: timeout, + pending: make(map[uint64]pendingPing), + seen: make(map[uint64]struct{}), + } +} + +func (t *pingTracker) markSent(seq uint64, sentAt time.Time) { + t.sent++ + t.pending[seq] = pendingPing{deadline: sentAt.Add(t.timeout)} + t.seen[seq] = struct{}{} +} + +func (t *pingTracker) observeReply(payload pingPayload, receivedAt time.Time) replyResult { + if _, ok := t.seen[payload.Seq]; !ok { + return replyResult{disposition: replyUnexpected} + } + + if _, ok := t.pending[payload.Seq]; !ok { + t.duplicates++ + return replyResult{disposition: replyDuplicate} + } + + delete(t.pending, payload.Seq) + rtt := receivedAt.Sub(time.Unix(0, payload.TSUnixNano)) + if rtt < 0 { + rtt = 0 + } + t.samples = append(t.samples, rtt) + return replyResult{ + disposition: replyMatched, + rtt: rtt, + } +} + +func (t *pingTracker) expire(now time.Time) []uint64 { + expired := make([]uint64, 0) + for seq, pending := range t.pending { + if !pending.deadline.After(now) { + expired = append(expired, seq) + delete(t.pending, seq) + } + } + sort.Slice(expired, func(i, j int) bool { + return expired[i] < expired[j] + }) + return expired +} + +func (t *pingTracker) pendingCount() int { + return len(t.pending) +} + +func (t *pingTracker) summary() rttSummary { + return calculateRTTSummary(t.samples, t.sent, t.duplicates) +} + +func calculateRTTSummary(samples []time.Duration, sent, duplicates int) rttSummary { + summary := rttSummary{ + Sent: sent, + Received: len(samples), + Duplicates: duplicates, + } + if sent > 0 { + summary.LossPct = float64(sent-len(samples)) * 100 / float64(sent) + } + if len(samples) == 0 { + return summary + } + + sorted := append([]time.Duration(nil), samples...) + sort.Slice(sorted, func(i, j int) bool { + return sorted[i] < sorted[j] + }) + + var sum float64 + for _, sample := range sorted { + sum += float64(sample) + } + avg := sum / float64(len(sorted)) + + var variance float64 + for _, sample := range sorted { + delta := float64(sample) - avg + variance += delta * delta + } + variance /= float64(len(sorted)) + + summary.Min = sorted[0] + summary.Avg = time.Duration(math.Round(avg)) + summary.Max = sorted[len(sorted)-1] + summary.P50 = percentileDuration(sorted, 0.50) + summary.P95 = percentileDuration(sorted, 0.95) + summary.P99 = percentileDuration(sorted, 0.99) + summary.StdDev = time.Duration(math.Round(math.Sqrt(variance))) + summary.HasSamples = true + return summary +} + +func percentileDuration(sorted []time.Duration, percentile float64) time.Duration { + if len(sorted) == 0 { + return 0 + } + if percentile <= 0 { + return sorted[0] + } + if percentile >= 1 { + return sorted[len(sorted)-1] + } + + index := int(math.Ceil(percentile*float64(len(sorted)))) - 1 + if index < 0 { + index = 0 + } + if index >= len(sorted) { + index = len(sorted) - 1 + } + return sorted[index] +} + +func formatRTT(duration time.Duration) string { + return fmt.Sprintf("%.2fms", float64(duration)/float64(time.Millisecond)) +} + +func writePingHeader(w io.Writer, cfg config) error { + _, err := fmt.Fprintf(w, "UDP PING %s via %s (payload=%d bytes, UDP)\n", cfg.to, cfg.server, cfg.size) + return err +} + +func writeMatchedReply(w io.Writer, seq uint64, rtt time.Duration) error { + _, err := fmt.Fprintf(w, "seq=%d rtt=%s\n", seq, formatRTT(rtt)) + return err +} + +func writeTimeout(w io.Writer, seq uint64) error { + _, err := fmt.Fprintf(w, "seq=%d timeout\n", seq) + return err +} + +func writeSummary(w io.Writer, target string, summary rttSummary) error { + if _, err := fmt.Fprintf(w, "--- %s udp ping statistics ---\n", target); err != nil { + return err + } + if _, err := fmt.Fprintf( + w, + "%d packets transmitted, %d received, %d duplicates, %.2f%% packet loss\n", + summary.Sent, + summary.Received, + summary.Duplicates, + summary.LossPct, + ); err != nil { + return err + } + + if !summary.HasSamples { + _, err := fmt.Fprintln(w, "rtt min/avg/max/p50/p95/p99 = n/a/n/a/n/a/n/a/n/a/n/a, stddev=n/a") + return err + } + + _, err := fmt.Fprintf( + w, + "rtt min/avg/max/p50/p95/p99 = %s/%s/%s/%s/%s/%s, stddev=%s\n", + formatRTT(summary.Min), + formatRTT(summary.Avg), + formatRTT(summary.Max), + formatRTT(summary.P50), + formatRTT(summary.P95), + formatRTT(summary.P99), + formatRTT(summary.StdDev), + ) + return err +} + +func expiryPollInterval(timeout time.Duration) time.Duration { + interval := timeout / 4 + if interval < minExpiryPoll { + return minExpiryPoll + } + if interval > maxExpiryPoll { + return maxExpiryPoll + } + return interval +} diff --git a/cmd/udpping/main_test.go b/cmd/udpping/main_test.go new file mode 100644 index 0000000..adb1915 --- /dev/null +++ b/cmd/udpping/main_test.go @@ -0,0 +1,203 @@ +package main + +import ( + "bytes" + "strings" + "testing" + "time" +) + +func TestParseConfigDefaults(t *testing.T) { + cfg, err := parseConfig([]string{"-to", "peer-b"}, ioDiscard{}) + if err != nil { + t.Fatalf("parseConfig() error = %v", err) + } + + if cfg.id != defaultPeerID { + t.Fatalf("id = %q, want %q", cfg.id, defaultPeerID) + } + if cfg.server != defaultServer { + t.Fatalf("server = %q, want %q", cfg.server, defaultServer) + } + if cfg.count != defaultCount { + t.Fatalf("count = %d, want %d", cfg.count, defaultCount) + } + if cfg.interval != defaultInterval { + t.Fatalf("interval = %s, want %s", cfg.interval, defaultInterval) + } + if cfg.size != defaultSize { + t.Fatalf("size = %d, want %d", cfg.size, defaultSize) + } + if cfg.timeout != defaultTimeout { + t.Fatalf("timeout = %s, want %s", cfg.timeout, defaultTimeout) + } +} + +func TestParseConfigRequiresTargetInPingMode(t *testing.T) { + _, err := parseConfig([]string{"-echo=false"}, ioDiscard{}) + if err == nil || !strings.Contains(err.Error(), "flag -to is required") { + t.Fatalf("parseConfig() error = %v, want missing -to error", err) + } +} + +func TestParseConfigAllowsEchoWithoutTarget(t *testing.T) { + cfg, err := parseConfig([]string{"-echo"}, ioDiscard{}) + if err != nil { + t.Fatalf("parseConfig() error = %v", err) + } + if !cfg.echo { + t.Fatal("echo = false, want true") + } +} + +func TestParseConfigRejectsBindDeviceFlag(t *testing.T) { + _, err := parseConfig([]string{"-to", "peer-b", "-bind-device", "wwan0"}, ioDiscard{}) + if err == nil || !strings.Contains(err.Error(), "flag provided but not defined") { + t.Fatalf("parseConfig() error = %v, want unknown flag error", err) + } +} + +func TestBuildPingPayloadUsesExactSize(t *testing.T) { + body, err := buildPingPayload(7, 123456789, 96) + if err != nil { + t.Fatalf("buildPingPayload() error = %v", err) + } + if len(body) != 96 { + t.Fatalf("len(body) = %d, want 96", len(body)) + } + + payload, err := parsePingPayload(body) + if err != nil { + t.Fatalf("parsePingPayload() error = %v", err) + } + if payload.Seq != 7 { + t.Fatalf("seq = %d, want 7", payload.Seq) + } + if payload.TSUnixNano != 123456789 { + t.Fatalf("ts_ns = %d, want 123456789", payload.TSUnixNano) + } +} + +func TestBuildPingPayloadRejectsTooSmallSize(t *testing.T) { + _, err := buildPingPayload(1, 123456789, 8) + if err == nil || !strings.Contains(err.Error(), "too small") { + t.Fatalf("buildPingPayload() error = %v, want size too small error", err) + } +} + +func TestParsePingPayloadRejectsInvalidJSON(t *testing.T) { + _, err := parsePingPayload([]byte("not-json")) + if err == nil || !strings.Contains(err.Error(), "decode ping payload") { + t.Fatalf("parsePingPayload() error = %v, want decode error", err) + } +} + +func TestPingTrackerHandlesMatchedDuplicateAndTimeout(t *testing.T) { + tracker := newPingTracker(50 * time.Millisecond) + sentAt := time.Unix(0, 100) + tracker.markSent(1, sentAt) + + match := tracker.observeReply(pingPayload{Seq: 1, TSUnixNano: sentAt.UnixNano()}, sentAt.Add(12*time.Millisecond)) + if match.disposition != replyMatched { + t.Fatalf("first disposition = %v, want matched", match.disposition) + } + if match.rtt != 12*time.Millisecond { + t.Fatalf("first rtt = %s, want 12ms", match.rtt) + } + + duplicate := tracker.observeReply(pingPayload{Seq: 1, TSUnixNano: sentAt.UnixNano()}, sentAt.Add(20*time.Millisecond)) + if duplicate.disposition != replyDuplicate { + t.Fatalf("second disposition = %v, want duplicate", duplicate.disposition) + } + + tracker.markSent(2, sentAt) + expired := tracker.expire(sentAt.Add(60 * time.Millisecond)) + if len(expired) != 1 || expired[0] != 2 { + t.Fatalf("expired = %v, want [2]", expired) + } + + late := tracker.observeReply(pingPayload{Seq: 2, TSUnixNano: sentAt.UnixNano()}, sentAt.Add(70*time.Millisecond)) + if late.disposition != replyDuplicate { + t.Fatalf("late disposition = %v, want duplicate", late.disposition) + } + + unexpected := tracker.observeReply(pingPayload{Seq: 99, TSUnixNano: sentAt.UnixNano()}, sentAt.Add(80*time.Millisecond)) + if unexpected.disposition != replyUnexpected { + t.Fatalf("unexpected disposition = %v, want unexpected", unexpected.disposition) + } +} + +func TestCalculateRTTSummary(t *testing.T) { + summary := calculateRTTSummary( + []time.Duration{ + 10 * time.Millisecond, + 20 * time.Millisecond, + 30 * time.Millisecond, + 40 * time.Millisecond, + 50 * time.Millisecond, + }, + 6, + 2, + ) + + if summary.Sent != 6 { + t.Fatalf("Sent = %d, want 6", summary.Sent) + } + if summary.Received != 5 { + t.Fatalf("Received = %d, want 5", summary.Received) + } + if summary.Duplicates != 2 { + t.Fatalf("Duplicates = %d, want 2", summary.Duplicates) + } + if summary.LossPct != (float64(1) * 100 / 6) { + t.Fatalf("LossPct = %f, want %f", summary.LossPct, float64(1)*100/6) + } + if summary.Min != 10*time.Millisecond { + t.Fatalf("Min = %s, want 10ms", summary.Min) + } + if summary.Avg != 30*time.Millisecond { + t.Fatalf("Avg = %s, want 30ms", summary.Avg) + } + if summary.Max != 50*time.Millisecond { + t.Fatalf("Max = %s, want 50ms", summary.Max) + } + if summary.P50 != 30*time.Millisecond { + t.Fatalf("P50 = %s, want 30ms", summary.P50) + } + if summary.P95 != 50*time.Millisecond { + t.Fatalf("P95 = %s, want 50ms", summary.P95) + } + if summary.P99 != 50*time.Millisecond { + t.Fatalf("P99 = %s, want 50ms", summary.P99) + } + if summary.StdDev == 0 { + t.Fatal("StdDev = 0, want non-zero") + } +} + +func TestWriteSummaryUsesNAWithoutSamples(t *testing.T) { + var buf bytes.Buffer + err := writeSummary(&buf, "host", rttSummary{ + Sent: 3, + Received: 0, + Duplicates: 1, + LossPct: 100, + }) + if err != nil { + t.Fatalf("writeSummary() error = %v", err) + } + + out := buf.String() + if !strings.Contains(out, "3 packets transmitted, 0 received, 1 duplicates, 100.00% packet loss") { + t.Fatalf("summary output missing counters: %q", out) + } + if !strings.Contains(out, "n/a/n/a/n/a/n/a/n/a/n/a") { + t.Fatalf("summary output missing n/a metrics: %q", out) + } +} + +type ioDiscard struct{} + +func (ioDiscard) Write(p []byte) (int, error) { + return len(p), nil +} diff --git a/cmd/udpping/platform_linux.go b/cmd/udpping/platform_linux.go new file mode 100644 index 0000000..7780b69 --- /dev/null +++ b/cmd/udpping/platform_linux.go @@ -0,0 +1,243 @@ +//go:build linux + +package main + +import ( + "context" + "errors" + "fmt" + "io" + "os" + "os/signal" + "strings" + "time" + + "omnisocketgo/cmd/internal/latencylog" + peerpkg "omnisocketgo/cmd/internal/peer" + "omnisocketgo/cmd/internal/protocol" +) + +func runPlatform(cfg config, stdout, stderr io.Writer, now func() time.Time) error { + ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt) + defer stop() + + client, closeLogger, err := dialUDPClient(cfg) + if err != nil { + return err + } + defer closeLogger() + defer client.Close() + + if cfg.echo { + return runEchoMode(ctx, client, stderr) + } + return runPingMode(ctx, client, cfg, stdout, stderr, now) +} + +func dialUDPClient(cfg config) (*peerpkg.UDPClient, func(), error) { + options := make([]peerpkg.Option, 0, 2) + closeLogger := func() {} + + if cfg.latencyLog != "" { + logger, err := latencylog.NewJSONLLogger(cfg.latencyLog) + if err != nil { + return nil, nil, fmt.Errorf("create latency logger %s: %w", cfg.latencyLog, err) + } + options = append(options, peerpkg.WithLogger(logger)) + closeLogger = func() { + _ = logger.Close() + } + } + if cfg.bindIP != "" { + options = append(options, peerpkg.WithBindIP(cfg.bindIP)) + } + + client, err := peerpkg.DialUDP(cfg.server, cfg.id, options...) + if err != nil { + closeLogger() + return nil, nil, fmt.Errorf("dial udp server %s: %w", cfg.server, err) + } + return client, closeLogger, nil +} + +func runPingMode(ctx context.Context, client *peerpkg.UDPClient, cfg config, stdout, stderr io.Writer, now func() time.Time) error { + if err := writePingHeader(stdout, cfg); err != nil { + return err + } + + receiveCh := make(chan protocol.Message, 32) + receiveErrCh := make(chan error, 1) + go func() { + for { + msg, err := client.Receive() + if err != nil { + receiveErrCh <- err + return + } + receiveCh <- msg + } + }() + + tracker := newPingTracker(cfg.timeout) + expiryTicker := time.NewTicker(expiryPollInterval(cfg.timeout)) + defer expiryTicker.Stop() + + var sendTicker *time.Ticker + if cfg.count == 0 || cfg.count > 1 { + sendTicker = time.NewTicker(cfg.interval) + defer sendTicker.Stop() + } + + nextSeq := uint64(1) + if err := sendPing(client, tracker, cfg, nextSeq, now); err != nil { + return err + } + nextSeq++ + + stopSending := cfg.count == 1 + receiveErrSeen := false + + for { + if stopSending && tracker.pendingCount() == 0 { + break + } + + var sendTick <-chan time.Time + if !stopSending && sendTicker != nil { + sendTick = sendTicker.C + } + + select { + case <-ctx.Done(): + stopSending = true + case <-expiryTicker.C: + for _, seq := range tracker.expire(now()) { + if err := writeTimeout(stdout, seq); err != nil { + return err + } + } + case msg := <-receiveCh: + if err := handlePingMessage(tracker, msg, stdout, stderr, now); err != nil { + return err + } + case err := <-receiveErrCh: + receiveErrSeen = true + if ctx.Err() != nil && isExpectedCloseError(err) { + break + } + if stopSending && tracker.pendingCount() == 0 && isExpectedCloseError(err) { + break + } + return fmt.Errorf("receive reply: %w", err) + case <-sendTick: + if cfg.count > 0 && tracker.sent >= cfg.count { + stopSending = true + continue + } + if err := sendPing(client, tracker, cfg, nextSeq, now); err != nil { + return err + } + nextSeq++ + if cfg.count > 0 && tracker.sent >= cfg.count { + stopSending = true + } + } + + if receiveErrSeen && stopSending && tracker.pendingCount() == 0 { + break + } + } + + for _, seq := range tracker.expire(now()) { + if err := writeTimeout(stdout, seq); err != nil { + return err + } + } + return writeSummary(stdout, cfg.to, tracker.summary()) +} + +func sendPing(client *peerpkg.UDPClient, tracker *pingTracker, cfg config, seq uint64, now func() time.Time) error { + sentAt := now() + payload, err := buildPingPayload(seq, sentAt.UnixNano(), cfg.size) + if err != nil { + return err + } + if err := client.SendText(cfg.to, string(payload)); err != nil { + return fmt.Errorf("send ping seq=%d: %w", seq, err) + } + tracker.markSent(seq, sentAt) + return nil +} + +func handlePingMessage(tracker *pingTracker, msg protocol.Message, stdout, stderr io.Writer, now func() time.Time) error { + switch msg.Type { + case protocol.MessageTypeText: + payload, err := parsePingPayload(msg.Body) + if err != nil { + _, writeErr := fmt.Fprintf(stderr, "ignore non-ping text message from %s: %v\n", msg.From, err) + if writeErr != nil { + return writeErr + } + return nil + } + + result := tracker.observeReply(payload, now()) + switch result.disposition { + case replyMatched: + return writeMatchedReply(stdout, payload.Seq, result.rtt) + case replyDuplicate: + _, err := fmt.Fprintf(stderr, "seq=%d duplicate or late reply ignored\n", payload.Seq) + return err + case replyUnexpected: + _, err := fmt.Fprintf(stderr, "seq=%d unexpected reply ignored\n", payload.Seq) + return err + default: + return nil + } + case protocol.MessageTypeError: + _, err := fmt.Fprintf(stderr, "server error: %s\n", strings.TrimSpace(string(msg.Body))) + return err + default: + _, err := fmt.Fprintf(stderr, "unexpected message type %s from %s ignored\n", msg.Type, msg.From) + return err + } +} + +func runEchoMode(ctx context.Context, client *peerpkg.UDPClient, stderr io.Writer) error { + receiveErrCh := make(chan error, 1) + go func() { + receiveErrCh <- client.ReceiveLoop(func(msg protocol.Message) error { + switch msg.Type { + case protocol.MessageTypeText: + return client.SendText(msg.From, string(msg.Body)) + case protocol.MessageTypeError: + _, err := fmt.Fprintf(stderr, "server error: %s\n", strings.TrimSpace(string(msg.Body))) + return err + default: + _, err := fmt.Fprintf(stderr, "unexpected message type %s from %s ignored\n", msg.Type, msg.From) + return err + } + }) + }() + + select { + case <-ctx.Done(): + return nil + case err := <-receiveErrCh: + if err == nil || (ctx.Err() != nil && isExpectedCloseError(err)) { + return nil + } + return fmt.Errorf("echo receive loop: %w", err) + } +} + +func isExpectedCloseError(err error) bool { + if err == nil { + return true + } + message := err.Error() + return errors.Is(err, context.Canceled) || + strings.Contains(message, "closed") || + strings.Contains(message, "broken pipe") || + strings.Contains(message, "io: read/write on closed pipe") +} diff --git a/cmd/udpping/platform_other.go b/cmd/udpping/platform_other.go new file mode 100644 index 0000000..d13b792 --- /dev/null +++ b/cmd/udpping/platform_other.go @@ -0,0 +1,14 @@ +//go:build !linux + +package main + +import ( + "fmt" + "io" + "runtime" + "time" +) + +func runPlatform(cfg config, stdout, stderr io.Writer, now func() time.Time) error { + return fmt.Errorf("udpping is only supported on linux; current GOOS=%s", runtime.GOOS) +}