Compare commits

...

9 Commits

8 changed files with 263 additions and 62 deletions

2
.gitignore vendored
View File

@@ -18,3 +18,5 @@ c/bin
/python/omnisocket.egg-info /python/omnisocket.egg-info
*.so* *.so*
/.venv

View File

@@ -26,6 +26,8 @@
#define WORKER_CONTROL_FD 3 #define WORKER_CONTROL_FD 3
#define WORKER_TELEMETRY_FD 4 #define WORKER_TELEMETRY_FD 4
#define WORKER_CONTROL_FD_ENV "OMNI_WORKER_CONTROL_FD"
#define WORKER_TELEMETRY_FD_ENV "OMNI_WORKER_TELEMETRY_FD"
#define NUM_BUFFERS 4 #define NUM_BUFFERS 4
#define CLEAR(x) memset(&(x), 0, sizeof(x)) #define CLEAR(x) memset(&(x), 0, sizeof(x))
@@ -238,11 +240,15 @@ static void telemetry_write_kcp_metrics(runtime_state_t *runtime, const kcp_conn
static int load_worker_config(worker_config_t *cfg) { static int load_worker_config(worker_config_t *cfg) {
const char *server_addr = getenv(VIDEO_SERVER_ADDR_ENV); const char *server_addr = getenv(VIDEO_SERVER_ADDR_ENV);
const char *relay_addr = env_or_default(VIDEO_RELAY_ADDR_ENV, "");
if (cfg == NULL) { if (cfg == NULL) {
errno = EINVAL; errno = EINVAL;
return -1; return -1;
} }
if ((server_addr == NULL || server_addr[0] == '\0') && relay_addr[0] != '\0') {
server_addr = relay_addr;
}
if (server_addr == NULL || server_addr[0] == '\0') { if (server_addr == NULL || server_addr[0] == '\0') {
fprintf(stderr, "%s is required\n", VIDEO_SERVER_ADDR_ENV); fprintf(stderr, "%s is required\n", VIDEO_SERVER_ADDR_ENV);
errno = EINVAL; errno = EINVAL;
@@ -251,7 +257,7 @@ static int load_worker_config(worker_config_t *cfg) {
CLEAR(*cfg); CLEAR(*cfg);
snprintf(cfg->server_addr, sizeof(cfg->server_addr), "%s", server_addr); snprintf(cfg->server_addr, sizeof(cfg->server_addr), "%s", server_addr);
snprintf(cfg->relay_addr, sizeof(cfg->relay_addr), "%s", env_or_default(VIDEO_RELAY_ADDR_ENV, "")); snprintf(cfg->relay_addr, sizeof(cfg->relay_addr), "%s", relay_addr);
snprintf(cfg->bind_ip, sizeof(cfg->bind_ip), "%s", env_or_default(VIDEO_BIND_IP_ENV, "")); snprintf(cfg->bind_ip, sizeof(cfg->bind_ip), "%s", env_or_default(VIDEO_BIND_IP_ENV, ""));
snprintf(cfg->bind_device, sizeof(cfg->bind_device), "%s", env_or_default(VIDEO_BIND_DEVICE_ENV, "")); snprintf(cfg->bind_device, sizeof(cfg->bind_device), "%s", env_or_default(VIDEO_BIND_DEVICE_ENV, ""));
snprintf(cfg->peer_id, sizeof(cfg->peer_id), "%s", env_or_default(VIDEO_PEER_ID_ENV, "peer-b-video")); snprintf(cfg->peer_id, sizeof(cfg->peer_id), "%s", env_or_default(VIDEO_PEER_ID_ENV, "peer-b-video"));
@@ -279,6 +285,7 @@ static int open_v4l2_device(const char *device) {
static int init_v4l2_device(int fd, const worker_config_t *cfg) { static int init_v4l2_device(int fd, const worker_config_t *cfg) {
struct v4l2_format fmt; struct v4l2_format fmt;
struct v4l2_streamparm parm;
CLEAR(fmt); CLEAR(fmt);
fmt.type = V4L2_BUF_TYPE_VIDEO_CAPTURE; fmt.type = V4L2_BUF_TYPE_VIDEO_CAPTURE;
@@ -291,6 +298,18 @@ static int init_v4l2_device(int fd, const worker_config_t *cfg) {
perror("VIDIOC_S_FMT"); perror("VIDIOC_S_FMT");
return -1; return -1;
} }
CLEAR(parm);
parm.type = V4L2_BUF_TYPE_VIDEO_CAPTURE;
if (cfg->initial_fps > 0 && ioctl(fd, VIDIOC_G_PARM, &parm) == 0) {
if ((parm.parm.capture.capability & V4L2_CAP_TIMEPERFRAME) != 0U) {
parm.parm.capture.timeperframe.numerator = 1U;
parm.parm.capture.timeperframe.denominator = (unsigned int) cfg->initial_fps;
if (ioctl(fd, VIDIOC_S_PARM, &parm) < 0) {
perror("VIDIOC_S_PARM");
}
}
}
return 0; return 0;
} }
@@ -353,6 +372,58 @@ static int queue_all_buffers(int fd, int num_buffers) {
return 0; return 0;
} }
static int dequeue_latest_buffer(int fd, struct v4l2_buffer *latest_buf) {
struct v4l2_buffer latest_local;
bool have_latest = false;
if (latest_buf == NULL) {
errno = EINVAL;
return -1;
}
for (;;) {
struct v4l2_buffer current;
int dq_errno;
CLEAR(current);
current.type = V4L2_BUF_TYPE_VIDEO_CAPTURE;
current.memory = V4L2_MEMORY_MMAP;
if (ioctl(fd, VIDIOC_DQBUF, &current) < 0) {
dq_errno = errno;
if (dq_errno == EINTR) {
continue;
}
if (dq_errno == EAGAIN) {
if (!have_latest) {
errno = EAGAIN;
return 1;
}
*latest_buf = latest_local;
return 0;
}
if (have_latest && ioctl(fd, VIDIOC_QBUF, &latest_local) < 0) {
perror("VIDIOC_QBUF");
}
errno = dq_errno;
return -1;
}
if (have_latest && ioctl(fd, VIDIOC_QBUF, &latest_local) < 0) {
int q_errno = errno;
perror("VIDIOC_QBUF");
if (ioctl(fd, VIDIOC_QBUF, &current) < 0) {
perror("VIDIOC_QBUF");
}
errno = q_errno;
return -1;
}
latest_local = current;
have_latest = true;
}
}
static AVCodecContext *create_mjpeg_decoder(const worker_config_t *cfg) { static AVCodecContext *create_mjpeg_decoder(const worker_config_t *cfg) {
const AVCodec *decoder = avcodec_find_decoder(AV_CODEC_ID_MJPEG); const AVCodec *decoder = avcodec_find_decoder(AV_CODEC_ID_MJPEG);
AVCodecContext *ctx; AVCodecContext *ctx;
@@ -607,6 +678,8 @@ int main(void) {
Buffer *buffers = NULL; Buffer *buffers = NULL;
int num_buffers = 0; int num_buffers = 0;
int camera_fd = -1; int camera_fd = -1;
int control_fd = env_as_int(WORKER_CONTROL_FD_ENV, WORKER_CONTROL_FD);
int telemetry_fd = env_as_int(WORKER_TELEMETRY_FD_ENV, WORKER_TELEMETRY_FD);
enum v4l2_buf_type stream_type = V4L2_BUF_TYPE_VIDEO_CAPTURE; enum v4l2_buf_type stream_type = V4L2_BUF_TYPE_VIDEO_CAPTURE;
AVCodecContext *decoder = NULL; AVCodecContext *decoder = NULL;
AVCodecContext *encoder = NULL; AVCodecContext *encoder = NULL;
@@ -636,8 +709,8 @@ int main(void) {
runtime.jpeg_quality_qscale = cfg.initial_qscale; runtime.jpeg_quality_qscale = cfg.initial_qscale;
runtime.max_frame_bytes = cfg.initial_max_frame_bytes; runtime.max_frame_bytes = cfg.initial_max_frame_bytes;
control_stream = fdopen(WORKER_CONTROL_FD, "r"); control_stream = fdopen(control_fd, "r");
telemetry_stream = fdopen(WORKER_TELEMETRY_FD, "w"); telemetry_stream = fdopen(telemetry_fd, "w");
if (control_stream == NULL || telemetry_stream == NULL) { if (control_stream == NULL || telemetry_stream == NULL) {
perror("fdopen worker control/telemetry"); perror("fdopen worker control/telemetry");
goto cleanup; goto cleanup;
@@ -707,10 +780,6 @@ int main(void) {
if (fps < 1) { if (fps < 1) {
fps = 1; fps = 1;
} }
if (next_deadline_ms > now_ms) {
usleep((useconds_t) ((next_deadline_ms - now_ms) * 1000.0));
}
next_deadline_ms = monotonic_ms() + (1000.0 / (double) fps);
FD_ZERO(&fds); FD_ZERO(&fds);
FD_SET(camera_fd, &fds); FD_SET(camera_fd, &fds);
@@ -725,10 +794,7 @@ int main(void) {
continue; continue;
} }
CLEAR(buf); if (dequeue_latest_buffer(camera_fd, &buf) != 0) {
buf.type = V4L2_BUF_TYPE_VIDEO_CAPTURE;
buf.memory = V4L2_MEMORY_MMAP;
if (ioctl(camera_fd, VIDIOC_DQBUF, &buf) < 0) {
if (errno == EAGAIN) { if (errno == EAGAIN) {
continue; continue;
} }
@@ -736,6 +802,13 @@ int main(void) {
break; break;
} }
now_ms = monotonic_ms();
if (now_ms < next_deadline_ms) {
drop_reason = "paced_drop";
goto requeue_and_report;
}
next_deadline_ms = now_ms + (1000.0 / (double) fps);
if (decode_mjpeg_frame(decoder, (uint8_t *) buffers[buf.index].start, (int) buf.bytesused, &decoded_frame) != 0) { if (decode_mjpeg_frame(decoder, (uint8_t *) buffers[buf.index].start, (int) buf.bytesused, &decoded_frame) != 0) {
drop_reason = "decode_failed"; drop_reason = "decode_failed";
goto requeue_and_report; goto requeue_and_report;

View File

@@ -1,6 +1,6 @@
transport: transport:
server_addr: "81.70.156.140:10909" server_addr: ""
relay_via: "106.55.173.235:10909" relay_via: "81.70.156.140:10909"
bind_ip: "" bind_ip: ""
bind_device: "" bind_device: ""
@@ -17,7 +17,7 @@ control_receiver:
queue_capacity: 256 queue_capacity: 256
video_sender: video_sender:
enabled: false enabled: true
peer_id: "peer-b-video" peer_id: "peer-b-video"
target_peer: "peer-a-video" target_peer: "peer-a-video"
binary_path: "bin/b_side_video_sender" binary_path: "bin/b_side_video_sender"
@@ -39,7 +39,7 @@ daemon:
worker_restart_delay_ms: 2000 worker_restart_delay_ms: 2000
policy: policy:
mode: "auto" mode: "manual"
health_window_ms: 2000 health_window_ms: 2000
green_srtt_ms: 30 green_srtt_ms: 30
yellow_srtt_ms: 55 yellow_srtt_ms: 55

View File

@@ -10,10 +10,11 @@ from pathlib import Path
import queue import queue
import signal import signal
import socketserver import socketserver
import sys
import threading import threading
import time import time
from dataclasses import dataclass from dataclasses import dataclass
from datetime import UTC, datetime from datetime import datetime, timezone
from http import HTTPStatus from http import HTTPStatus
from http.server import BaseHTTPRequestHandler from http.server import BaseHTTPRequestHandler
from typing import Any from typing import Any
@@ -25,13 +26,13 @@ from .control_codec import ANALOG_EVENT_CODES, EVENT_NAME_TO_ID, make_control_pa
def utc_iso_now() -> str: def utc_iso_now() -> str:
return datetime.now(UTC).isoformat(timespec="seconds").replace("+00:00", "Z") return datetime.now(timezone.utc).isoformat(timespec="seconds").replace("+00:00", "Z")
def load_omnisocket_api(): def load_omnisocket_api():
from omnisocket import CONTROL_DEFAULTS, MSG_TYPE_BINARY, Session, VIDEO_DEFAULTS from omnisocket import CONTROL_DEFAULTS, MSG_TYPE_BINARY, MSG_TYPE_ERROR, Session, VIDEO_DEFAULTS
return CONTROL_DEFAULTS, MSG_TYPE_BINARY, Session, VIDEO_DEFAULTS return CONTROL_DEFAULTS, MSG_TYPE_BINARY, MSG_TYPE_ERROR, Session, VIDEO_DEFAULTS
def _merge_kcp_defaults(defaults: dict[str, Any], override: dict[str, Any]) -> dict[str, Any]: def _merge_kcp_defaults(defaults: dict[str, Any], override: dict[str, Any]) -> dict[str, Any]:
@@ -49,7 +50,7 @@ def _load_config(config_path: str | None) -> dict[str, Any]:
with path.open("r", encoding="utf-8") as file: with path.open("r", encoding="utf-8") as file:
raw = yaml.safe_load(file) or {} raw = yaml.safe_load(file) or {}
control_defaults, _msg_type_binary, _session_cls, video_defaults = load_omnisocket_api() control_defaults, _msg_type_binary, _msg_type_error, _session_cls, video_defaults = load_omnisocket_api()
transport = dict(raw.get("transport", {})) transport = dict(raw.get("transport", {}))
control = dict(raw.get("control_sender", {})) control = dict(raw.get("control_sender", {}))
@@ -159,7 +160,7 @@ class QueuedControlEvent:
class ControlSessionManager: class ControlSessionManager:
def __init__(self, config: dict[str, Any]) -> None: def __init__(self, config: dict[str, Any]) -> None:
control_defaults, _msg_type_binary, session_cls, _video_defaults = load_omnisocket_api() control_defaults, _msg_type_binary, _msg_type_error, session_cls, _video_defaults = load_omnisocket_api()
transport = config["transport"] transport = config["transport"]
control_cfg = config["control_sender"] control_cfg = config["control_sender"]
daemon_cfg = config["daemon"] daemon_cfg = config["daemon"]
@@ -402,12 +403,13 @@ class ControlSessionManager:
class VideoSessionManager: class VideoSessionManager:
def __init__(self, config: dict[str, Any]) -> None: def __init__(self, config: dict[str, Any]) -> None:
_control_defaults, msg_type_binary, session_cls, video_defaults = load_omnisocket_api() _control_defaults, msg_type_binary, msg_type_error, session_cls, video_defaults = load_omnisocket_api()
transport = config["transport"] transport = config["transport"]
video_cfg = config["video_receiver"] video_cfg = config["video_receiver"]
daemon_cfg = config["daemon"] daemon_cfg = config["daemon"]
self._msg_type_binary = msg_type_binary self._msg_type_binary = msg_type_binary
self._msg_type_error = msg_type_error
self._session_cls = session_cls self._session_cls = session_cls
self._connect_kwargs = { self._connect_kwargs = {
"server_addr": transport["server_addr"], "server_addr": transport["server_addr"],
@@ -516,9 +518,12 @@ class VideoSessionManager:
meta = session.recv_into(buffer, timeout_ms=200) meta = session.recv_into(buffer, timeout_ms=200)
if meta is None: if meta is None:
continue continue
if meta.get("msg_type") != self._msg_type_binary: msg_type = int(meta.get("msg_type", -1))
continue
frame = bytes(buffer[: int(meta["body_len"])]) frame = bytes(buffer[: int(meta["body_len"])])
if msg_type != self._msg_type_binary:
self._disconnect(self._describe_unexpected_message(msg_type, frame))
time.sleep(0.2)
break
jpeg_frame = self._extract_jpeg_frame(frame) jpeg_frame = self._extract_jpeg_frame(frame)
if jpeg_frame is None: if jpeg_frame is None:
with self._lock: with self._lock:
@@ -534,6 +539,14 @@ class VideoSessionManager:
self._disconnect(str(error)) self._disconnect(str(error))
time.sleep(0.2) time.sleep(0.2)
def _describe_unexpected_message(self, msg_type: int, payload: bytes) -> str:
detail = payload.decode("utf-8", errors="replace").strip()
if msg_type == self._msg_type_error:
return f"video session rejected by server: {detail or 'unknown error'}"
if detail:
return f"received unexpected video message type {msg_type}: {detail}"
return f"received unexpected video message type {msg_type}"
def _connect(self) -> None: def _connect(self) -> None:
session = self._session_cls() session = self._session_cls()
try: try:
@@ -958,8 +971,11 @@ class OmniDaemonHTTPHandler(BaseHTTPRequestHandler):
self.send_header("Content-Length", str(len(payload))) self.send_header("Content-Length", str(len(payload)))
self.send_header("Cache-Control", "no-store") self.send_header("Cache-Control", "no-store")
self.send_header("Connection", "keep-alive") self.send_header("Connection", "keep-alive")
self.end_headers() try:
self.wfile.write(payload) self.end_headers()
self.wfile.write(payload)
except (BrokenPipeError, ConnectionResetError):
return
class ASideOmniDaemon: class ASideOmniDaemon:
@@ -1043,6 +1059,15 @@ def main(argv: list[str] | None = None) -> None:
args = parser.parse_args(argv) args = parser.parse_args(argv)
app = ASideOmniDaemon(config_path=args.config_path) app = ASideOmniDaemon(config_path=args.config_path)
print(
(
"A-side OmniDaemon starting "
f"(config={app._config['config_path']}, "
f"socket={app._config['daemon']['socket_path']})"
),
file=sys.stderr,
flush=True,
)
def _handle_signal(_signum: int, _frame: Any) -> None: def _handle_signal(_signum: int, _frame: Any) -> None:
app.stop() app.stop()
@@ -1052,7 +1077,17 @@ def main(argv: list[str] | None = None) -> None:
signal.signal(signal.SIGTERM, _handle_signal) signal.signal(signal.SIGTERM, _handle_signal)
try: try:
app.serve_forever() app.start()
print(
(
"A-side OmniDaemon ready "
f"(state: curl --unix-socket {app.socket_path} http://localhost/v1/state)"
),
file=sys.stderr,
flush=True,
)
assert app._server is not None
app._server.serve_forever()
except KeyboardInterrupt: except KeyboardInterrupt:
pass pass
finally: finally:

View File

@@ -5,7 +5,7 @@ from __future__ import annotations
import argparse import argparse
import copy import copy
from dataclasses import dataclass from dataclasses import dataclass
from datetime import UTC, datetime from datetime import datetime, timezone
from http import HTTPStatus from http import HTTPStatus
from http.server import BaseHTTPRequestHandler from http.server import BaseHTTPRequestHandler
import json import json
@@ -16,6 +16,7 @@ import signal
import socket import socket
import socketserver import socketserver
import subprocess import subprocess
import sys
import threading import threading
import time import time
from typing import Any from typing import Any
@@ -34,13 +35,13 @@ from . import (
def utc_iso_now() -> str: def utc_iso_now() -> str:
return datetime.now(UTC).isoformat(timespec="seconds").replace("+00:00", "Z") return datetime.now(timezone.utc).isoformat(timespec="seconds").replace("+00:00", "Z")
def load_omnisocket_api(): def load_omnisocket_api():
from omnisocket import CONTROL_DEFAULTS, MSG_TYPE_BINARY, Session from omnisocket import CONTROL_DEFAULTS, MSG_TYPE_BINARY, MSG_TYPE_ERROR, Session
return CONTROL_DEFAULTS, MSG_TYPE_BINARY, Session return CONTROL_DEFAULTS, MSG_TYPE_BINARY, MSG_TYPE_ERROR, Session
def _merge_kcp_defaults(defaults: dict[str, Any], override: dict[str, Any]) -> dict[str, Any]: def _merge_kcp_defaults(defaults: dict[str, Any], override: dict[str, Any]) -> dict[str, Any]:
@@ -114,7 +115,7 @@ def _load_config(config_path: str | None) -> dict[str, Any]:
with path.open("r", encoding="utf-8") as file: with path.open("r", encoding="utf-8") as file:
raw = yaml.safe_load(file) or {} raw = yaml.safe_load(file) or {}
control_defaults, _msg_type_binary, _session_cls = load_omnisocket_api() control_defaults, _msg_type_binary, _msg_type_error, _session_cls = load_omnisocket_api()
transport = dict(raw.get("transport", {})) transport = dict(raw.get("transport", {}))
control = dict(raw.get("control_receiver", {})) control = dict(raw.get("control_receiver", {}))
video = dict(raw.get("video_sender", {})) video = dict(raw.get("video_sender", {}))
@@ -192,12 +193,13 @@ def _load_config(config_path: str | None) -> dict[str, Any]:
class ControlRecvManager: class ControlRecvManager:
def __init__(self, config: dict[str, Any]) -> None: def __init__(self, config: dict[str, Any]) -> None:
control_defaults, msg_type_binary, session_cls = load_omnisocket_api() control_defaults, msg_type_binary, msg_type_error, session_cls = load_omnisocket_api()
transport = config["transport"] transport = config["transport"]
control_cfg = config["control_receiver"] control_cfg = config["control_receiver"]
daemon_cfg = config["daemon"] daemon_cfg = config["daemon"]
self._msg_type_binary = msg_type_binary self._msg_type_binary = msg_type_binary
self._msg_type_error = msg_type_error
self._session_cls = session_cls self._session_cls = session_cls
self._connect_kwargs = { self._connect_kwargs = {
"server_addr": transport["server_addr"], "server_addr": transport["server_addr"],
@@ -306,6 +308,7 @@ class ControlRecvManager:
if msg_type != self._msg_type_binary: if msg_type != self._msg_type_binary:
with self._lock: with self._lock:
self._ignored_non_binary += 1 self._ignored_non_binary += 1
self._disconnect(self._describe_unexpected_message(msg_type, payload))
continue continue
if len(payload) != CONTROL_PACKET_STRUCT.size: if len(payload) != CONTROL_PACKET_STRUCT.size:
with self._lock: with self._lock:
@@ -345,6 +348,14 @@ class ControlRecvManager:
except Exception: except Exception:
pass pass
def _describe_unexpected_message(self, msg_type: int, payload: bytes) -> str:
detail = payload.decode("utf-8", errors="replace").strip()
if msg_type == self._msg_type_error:
return f"control session rejected by server: {detail or 'unknown error'}"
if detail:
return f"received unexpected control message type {msg_type}: {detail}"
return f"received unexpected control message type {msg_type}"
def _enqueue_packet(self, payload: bytes) -> None: def _enqueue_packet(self, payload: bytes) -> None:
try: try:
self._queue.put_nowait(payload) self._queue.put_nowait(payload)
@@ -540,6 +551,16 @@ class VideoWorkerManager:
def start(self) -> None: def start(self) -> None:
if not self._enabled: if not self._enabled:
return return
if not os.path.exists(self._video_cfg["binary_path"]):
print(
(
"B-side video worker binary missing: "
f"{self._video_cfg['binary_path']} "
"(run `make b_side_video_sender` in OmniSocketGo)"
),
file=sys.stderr,
flush=True,
)
self._thread.start() self._thread.start()
def stop(self) -> None: def stop(self) -> None:
@@ -650,16 +671,6 @@ class VideoWorkerManager:
command_read_fd, command_write_fd = os.pipe() command_read_fd, command_write_fd = os.pipe()
telemetry_read_fd, telemetry_write_fd = os.pipe() telemetry_read_fd, telemetry_write_fd = os.pipe()
def _preexec() -> None:
os.dup2(command_read_fd, 3)
os.dup2(telemetry_write_fd, 4)
for fd in (command_read_fd, command_write_fd, telemetry_read_fd, telemetry_write_fd):
if fd not in (3, 4):
try:
os.close(fd)
except OSError:
pass
env = dict(os.environ) env = dict(os.environ)
env.update( env.update(
{ {
@@ -675,6 +686,8 @@ class VideoWorkerManager:
"OMNI_VIDEO_OUTPUT_WIDTH": str(self._video_cfg["output_width"]), "OMNI_VIDEO_OUTPUT_WIDTH": str(self._video_cfg["output_width"]),
"OMNI_VIDEO_OUTPUT_HEIGHT": str(self._video_cfg["output_height"]), "OMNI_VIDEO_OUTPUT_HEIGHT": str(self._video_cfg["output_height"]),
"OMNI_VIDEO_STATS_INTERVAL_MS": str(self._video_cfg["stats_interval_ms"]), "OMNI_VIDEO_STATS_INTERVAL_MS": str(self._video_cfg["stats_interval_ms"]),
"OMNI_WORKER_CONTROL_FD": str(command_read_fd),
"OMNI_WORKER_TELEMETRY_FD": str(telemetry_write_fd),
} }
) )
with self._lock: with self._lock:
@@ -692,7 +705,6 @@ class VideoWorkerManager:
env=env, env=env,
close_fds=True, close_fds=True,
pass_fds=(command_read_fd, telemetry_write_fd), pass_fds=(command_read_fd, telemetry_write_fd),
preexec_fn=_preexec,
) )
except Exception as error: # pragma: no cover - runtime integration except Exception as error: # pragma: no cover - runtime integration
for fd in (command_read_fd, command_write_fd, telemetry_read_fd, telemetry_write_fd): for fd in (command_read_fd, command_write_fd, telemetry_read_fd, telemetry_write_fd):
@@ -1196,8 +1208,11 @@ class OmniDaemonHTTPHandler(BaseHTTPRequestHandler):
self.send_header("Content-Length", str(len(body))) self.send_header("Content-Length", str(len(body)))
self.send_header("Cache-Control", "no-store") self.send_header("Cache-Control", "no-store")
self.send_header("Connection", "keep-alive") self.send_header("Connection", "keep-alive")
self.end_headers() try:
self.wfile.write(body) self.end_headers()
self.wfile.write(body)
except (BrokenPipeError, ConnectionResetError):
return
class BSideOmniDaemon: class BSideOmniDaemon:
@@ -1259,6 +1274,14 @@ class BSideOmniDaemon:
daemon=True, daemon=True,
) )
self._server_thread.start() self._server_thread.start()
print(
(
"B-side OmniDaemon ready "
f"(state: curl --unix-socket {self.socket_path} http://localhost/v1/state)"
),
file=sys.stderr,
flush=True,
)
self._server_thread.join() self._server_thread.join()
def get_state(self) -> dict[str, Any]: def get_state(self) -> dict[str, Any]:
@@ -1299,6 +1322,17 @@ def main(argv: list[str] | None = None) -> None:
args = parser.parse_args(argv) args = parser.parse_args(argv)
app = BSideOmniDaemon(config_path=args.config_path) app = BSideOmniDaemon(config_path=args.config_path)
print(
(
"B-side OmniDaemon starting "
f"(config={app._config['config_path']}, "
f"socket={app._config['daemon']['socket_path']}, "
f"ctrl_socket={app._config['daemon']['ctrl_socket_path']}, "
f"video_enabled={app._config['video_sender']['enabled']})"
),
file=sys.stderr,
flush=True,
)
def _handle_signal(_signum: int, _frame: Any) -> None: def _handle_signal(_signum: int, _frame: Any) -> None:
app.stop() app.stop()

View File

@@ -36,6 +36,9 @@ setup(
name="omnisocket", name="omnisocket",
version="0.1.0", version="0.1.0",
packages=["omnisocket", "omnisocket_a_side", "omnisocket_b_side"], packages=["omnisocket", "omnisocket_a_side", "omnisocket_b_side"],
install_requires=[
"PyYAML>=6.0",
],
entry_points={ entry_points={
"console_scripts": [ "console_scripts": [
"omnisocket-a-side-daemon=omnisocket_a_side.daemon:main", "omnisocket-a-side-daemon=omnisocket_a_side.daemon:main",

0
scripts/start_b_side.sh Normal file → Executable file
View File

View File

@@ -5,6 +5,13 @@
#define UDP_RELAY_BUF_SIZE (64U * 1024U) #define UDP_RELAY_BUF_SIZE (64U * 1024U)
typedef struct udp_relay_client_entry {
struct udp_relay_client_entry *next;
uint32_t conv;
struct sockaddr_storage addr;
socklen_t addr_len;
} udp_relay_client_entry_t;
struct udp_relay { struct udp_relay {
int downstream_fd; int downstream_fd;
int upstream_fd; int upstream_fd;
@@ -12,9 +19,10 @@ struct udp_relay {
socklen_t upstream_addr_len; socklen_t upstream_addr_len;
char downstream_local_addr[OMNI_MAX_ADDR_TEXT]; char downstream_local_addr[OMNI_MAX_ADDR_TEXT];
char upstream_local_addr[OMNI_MAX_ADDR_TEXT]; char upstream_local_addr[OMNI_MAX_ADDR_TEXT];
struct sockaddr_storage client_addr; struct sockaddr_storage last_client_addr;
socklen_t client_addr_len; socklen_t last_client_addr_len;
int has_client; int has_last_client;
udp_relay_client_entry_t *clients;
pthread_mutex_t lock; pthread_mutex_t lock;
pthread_mutex_t log_mu; pthread_mutex_t log_mu;
pthread_mutex_t state_mu; pthread_mutex_t state_mu;
@@ -131,22 +139,55 @@ static void udp_relay_note_result(udp_relay_t *relay, int rc, int errnum) {
pthread_mutex_unlock(&relay->state_mu); pthread_mutex_unlock(&relay->state_mu);
} }
static void udp_relay_record_client(udp_relay_t *relay, const struct sockaddr_storage *addr, socklen_t addr_len) { static udp_relay_client_entry_t *udp_relay_find_client_locked(udp_relay_t *relay, uint32_t conv) {
udp_relay_client_entry_t *entry;
for (entry = relay->clients; entry != NULL; entry = entry->next) {
if (entry->conv == conv) {
return entry;
}
}
return NULL;
}
static void udp_relay_record_client(udp_relay_t *relay, int has_conv, uint32_t conv, const struct sockaddr_storage *addr, socklen_t addr_len) {
pthread_mutex_lock(&relay->lock); pthread_mutex_lock(&relay->lock);
memcpy(&relay->client_addr, addr, sizeof(*addr)); memcpy(&relay->last_client_addr, addr, sizeof(*addr));
relay->client_addr_len = addr_len; relay->last_client_addr_len = addr_len;
relay->has_client = 1; relay->has_last_client = 1;
if (has_conv) {
udp_relay_client_entry_t *entry = udp_relay_find_client_locked(relay, conv);
if (entry == NULL) {
entry = (udp_relay_client_entry_t *) calloc(1, sizeof(*entry));
if (entry != NULL) {
entry->conv = conv;
entry->next = relay->clients;
relay->clients = entry;
}
}
if (entry != NULL) {
memcpy(&entry->addr, addr, sizeof(*addr));
entry->addr_len = addr_len;
}
}
pthread_mutex_unlock(&relay->lock); pthread_mutex_unlock(&relay->lock);
} }
static int udp_relay_copy_client(udp_relay_t *relay, struct sockaddr_storage *addr, socklen_t *addr_len) { static int udp_relay_copy_client(udp_relay_t *relay, int has_conv, uint32_t conv, struct sockaddr_storage *addr, socklen_t *addr_len) {
int has_client; int has_client = 0;
pthread_mutex_lock(&relay->lock); pthread_mutex_lock(&relay->lock);
has_client = relay->has_client; if (has_conv) {
if (has_client) { udp_relay_client_entry_t *entry = udp_relay_find_client_locked(relay, conv);
memcpy(addr, &relay->client_addr, sizeof(*addr)); if (entry != NULL) {
*addr_len = relay->client_addr_len; memcpy(addr, &entry->addr, sizeof(*addr));
*addr_len = entry->addr_len;
has_client = 1;
}
} else if (relay->has_last_client) {
memcpy(addr, &relay->last_client_addr, sizeof(*addr));
*addr_len = relay->last_client_addr_len;
has_client = 1;
} }
pthread_mutex_unlock(&relay->lock); pthread_mutex_unlock(&relay->lock);
return has_client; return has_client;
@@ -160,6 +201,8 @@ static void *udp_relay_forward_downstream_to_upstream(void *arg) {
struct sockaddr_storage source; struct sockaddr_storage source;
socklen_t source_len = sizeof(source); socklen_t source_len = sizeof(source);
ssize_t n = recvfrom(relay->downstream_fd, buffer, sizeof(buffer), 0, (struct sockaddr *) &source, &source_len); ssize_t n = recvfrom(relay->downstream_fd, buffer, sizeof(buffer), 0, (struct sockaddr *) &source, &source_len);
int has_conv = 0;
uint32_t conv = 0;
if (n < 0) { if (n < 0) {
int errnum = errno; int errnum = errno;
@@ -174,7 +217,8 @@ static void *udp_relay_forward_downstream_to_upstream(void *arg) {
return NULL; return NULL;
} }
udp_relay_record_client(relay, &source, source_len); udp_relay_parse_kcp_summary(buffer, (size_t) n, &has_conv, &conv, NULL);
udp_relay_record_client(relay, has_conv, conv, &source, source_len);
udp_relay_print_packet(relay, "relay_downstream_rx", relay->downstream_local_addr, &source, source_len, buffer, (size_t) n); udp_relay_print_packet(relay, "relay_downstream_rx", relay->downstream_local_addr, &source, source_len, buffer, (size_t) n);
for (;;) { for (;;) {
if (send(relay->upstream_fd, buffer, (size_t) n, 0) >= 0) { if (send(relay->upstream_fd, buffer, (size_t) n, 0) >= 0) {
@@ -205,6 +249,8 @@ static void *udp_relay_forward_upstream_to_downstream(void *arg) {
struct sockaddr_storage client_addr; struct sockaddr_storage client_addr;
socklen_t client_addr_len = 0; socklen_t client_addr_len = 0;
ssize_t n = recv(relay->upstream_fd, buffer, sizeof(buffer), 0); ssize_t n = recv(relay->upstream_fd, buffer, sizeof(buffer), 0);
int has_conv = 0;
uint32_t conv = 0;
if (n < 0) { if (n < 0) {
int errnum = errno; int errnum = errno;
@@ -220,7 +266,8 @@ static void *udp_relay_forward_upstream_to_downstream(void *arg) {
} }
udp_relay_print_packet(relay, "relay_upstream_rx", relay->upstream_local_addr, &relay->upstream_addr, relay->upstream_addr_len, buffer, (size_t) n); udp_relay_print_packet(relay, "relay_upstream_rx", relay->upstream_local_addr, &relay->upstream_addr, relay->upstream_addr_len, buffer, (size_t) n);
if (!udp_relay_copy_client(relay, &client_addr, &client_addr_len)) { udp_relay_parse_kcp_summary(buffer, (size_t) n, &has_conv, &conv, NULL);
if (!udp_relay_copy_client(relay, has_conv, conv, &client_addr, &client_addr_len)) {
udp_relay_print_packet(relay, "relay_upstream_drop_no_client", relay->upstream_local_addr, &relay->upstream_addr, relay->upstream_addr_len, buffer, (size_t) n); udp_relay_print_packet(relay, "relay_upstream_drop_no_client", relay->upstream_local_addr, &relay->upstream_addr, relay->upstream_addr_len, buffer, (size_t) n);
continue; continue;
} }
@@ -404,11 +451,18 @@ int udp_relay_close(udp_relay_t *relay) {
} }
void udp_relay_free(udp_relay_t *relay) { void udp_relay_free(udp_relay_t *relay) {
udp_relay_client_entry_t *entry;
udp_relay_client_entry_t *next;
if (relay == NULL) { if (relay == NULL) {
return; return;
} }
udp_relay_close(relay); udp_relay_close(relay);
udp_relay_join_threads(relay); udp_relay_join_threads(relay);
for (entry = relay->clients; entry != NULL; entry = next) {
next = entry->next;
free(entry);
}
pthread_mutex_destroy(&relay->lock); pthread_mutex_destroy(&relay->lock);
pthread_mutex_destroy(&relay->log_mu); pthread_mutex_destroy(&relay->log_mu);
pthread_cond_destroy(&relay->state_cond); pthread_cond_destroy(&relay->state_cond);