radiance/cmd/proxy/proxy.go

186 lines
4.7 KiB
Go

package main
import (
"bytes"
"context"
"flag"
"github.com/LiamHaworth/go-tproxy"
"github.com/certusone/tpuproxy/pkg/endpoints"
"github.com/certusone/tpuproxy/pkg/netlink"
"github.com/certusone/tpuproxy/pkg/nftables"
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promauto"
"github.com/prometheus/client_golang/prometheus/promhttp"
"golang.org/x/sys/unix"
"k8s.io/klog/v2"
"net"
"net/http"
_ "net/http/pprof"
"os"
"os/signal"
"strconv"
"strings"
"sync/atomic"
"time"
)
// Command-line flags.
var (
	flagDebugAddr = flag.String("debug-addr", ":6060", "Metrics and pprof listen address")
	flagIface     = flag.String("iface", "", "External interface to receive packets from")
	flagPorts     = flag.String("ports", "", "Destination ports to proxy (comma-separated), asks local RPC if empty")
)

// Prometheus counters. Both carry a single "port" label; listen fills it
// with the original destination port of each received packet.
var (
	metricPacketsCount = promauto.NewCounterVec(
		prometheus.CounterOpts{
			Name: "tproxy_packets_count",
			Help: "Number of packets received by the proxy",
		},
		[]string{"port"},
	)
	metricBytesCount = promauto.NewCounterVec(
		prometheus.CounterOpts{
			Name: "tproxy_bytes_count",
			Help: "Number of bytes received by the proxy",
		},
		[]string{"port"},
	)
)
// main sets up the transparent UDP proxy and blocks until SIGINT/SIGTERM.
//
// Steps, in order:
//  1. Validate -iface and look up its primary IP.
//  2. Determine the destination ports, either from -ports or by asking the
//     local RPC endpoint.
//  3. Start the pprof/Prometheus HTTP server in the background.
//  4. Open a tproxy UDP listener on an ephemeral port and start the receive
//     loop.
//  5. Install the nftables tproxy chain pointing at that port, and remove it
//     again on shutdown.
func main() {
	flag.Parse()

	if *flagIface == "lo" {
		klog.Exitf("proxying lo would lead to a loopback packet loop")
	}
	if *flagIface == "" {
		klog.Exitf("no interface specified, use -iface to specify one")
	}

	dst, err := netlink.GetInterfaceIP(*flagIface)
	if err != nil {
		klog.Exit("failed to get IP: ", err)
	}
	klog.Infof("interface %s has primary IP %s", *flagIface, dst)

	var ports []uint16
	if *flagPorts == "" {
		klog.Infof("no ports specified, asking local RPC for ports")
		ports, err = endpoints.GetNodeTPUPorts(context.Background(), endpoints.RPCLocalhost, dst)
		if err != nil {
			klog.Exit("failed to get ports: ", err)
		}
		klog.Infof("found ports: %v", ports)
	} else {
		for _, port := range strings.Split(*flagPorts, ",") {
			// Tolerate whitespace around entries, e.g. "8000, 8001".
			p, err := strconv.ParseUint(strings.TrimSpace(port), 10, 16)
			if err != nil {
				klog.Exit("failed to parse port: ", err)
			}
			ports = append(ports, uint16(p))
		}
	}

	go func() {
		http.Handle("/metrics", promhttp.Handler())
		klog.Infof("Starting pprof and Prometheus server on %s", *flagDebugAddr)
		// NOTE(review): klog.Fatal terminates the process, so if the debug
		// server dies after the nft chain is installed, the deferred chain
		// removal below does not run.
		klog.Fatal(http.ListenAndServe(*flagDebugAddr, nil))
	}()

	// Get hostname
	hostname, err := os.Hostname()
	if err != nil {
		klog.Fatalf("Failed to get hostname: %v", err)
	}
	klog.Infof("Running on %s", hostname)

	// Port 0 lets the kernel pick a free port; the nft chain is then pointed
	// at whatever port was assigned.
	addr, err := net.ResolveUDPAddr("udp", ":0")
	if err != nil {
		klog.Exitf("Failed to resolve address: %v", err)
	}
	conn, err := tproxy.ListenUDP("udp", addr)
	if err != nil {
		klog.Exitf("Failed to listen on %s: %v", addr, err)
	}
	localPort := uint16(conn.LocalAddr().(*net.UDPAddr).Port)
	defer conn.Close()
	klog.Infof("Listening on %s", conn.LocalAddr())

	go listen(conn)

	if err := nftables.EnsureKernelModules(); err != nil {
		klog.Exitf("Failed to ensure kernel modules: %v", err)
	}
	if err := nftables.InsertProxyChain(ports, localPort, *flagIface); err != nil {
		klog.Exitf("Failed to insert nft tproxy chain: %v", err)
	}
	// Deferred after conn.Close, so the chain is removed before the socket
	// is closed.
	defer func() {
		if err := nftables.DeleteProxyChain(); err != nil {
			klog.Warningf("Failed to delete nft tproxy chain: %v", err)
			// Do not claim success when the deletion failed.
			return
		}
		klog.Infof("Deleted nft tproxy chain")
	}()
	klog.Infof("Inserted nft tproxy chain")

	// Block until asked to terminate; returning runs the deferred cleanup.
	sigint := make(chan os.Signal, 1)
	signal.Notify(sigint, unix.SIGINT, unix.SIGTERM)
	<-sigint
	klog.Infof("Shutting down")
}
// listen is the proxy's receive loop. It reads datagrams from the tproxy
// socket together with their original destination, updates the counters, and
// hands each packet to handlePacket in its own goroutine.
//
// A background goroutine logs and resets the byte/packet counters once per
// second. Neither loop has an exit condition; both run for the lifetime of
// the process.
func listen(conn *net.UDPConn) {
	// The UDP length field is 16 bits, so no datagram can exceed this size.
	// Reading into a buffer this large guarantees nothing is truncated.
	const maxDatagramSize = 65535

	inBytes := new(uint64)
	inPackets := new(uint64)

	// Periodically log stats
	go func() {
		for {
			// Swap reads and resets in one atomic step, so increments that
			// race with the reset are carried into the next interval
			// instead of being lost.
			b := atomic.SwapUint64(inBytes, 0)
			p := atomic.SwapUint64(inPackets, 0)
			klog.Infof("InBytes: %d, InPackets: %d", b, p)
			time.Sleep(time.Second)
		}
	}()

	// Reused for every read; the payload is copied out before being handed
	// to another goroutine.
	scratch := make([]byte, maxDatagramSize)
	for {
		n, src, dst, err := tproxy.ReadFromUDP(conn, scratch)
		if err != nil {
			klog.Errorf("Failed to read from UDP: %v", err)
			continue
		}
		// Forwarding a packet whose source equals its destination would
		// send it straight back to itself.
		if bytes.Equal(src.IP, dst.IP) && src.Port == dst.Port {
			klog.V(2).Infof("src and dst are identical, dropping packet")
			continue
		}

		atomic.AddUint64(inBytes, uint64(n))
		atomic.AddUint64(inPackets, 1)

		port := strconv.Itoa(dst.Port)
		metricBytesCount.WithLabelValues(port).Add(float64(n))
		metricPacketsCount.WithLabelValues(port).Add(1)

		// Forward exactly the n bytes received. The copy gives the handler
		// goroutine its own storage, because scratch is overwritten by the
		// next read.
		payload := make([]byte, n)
		copy(payload, scratch[:n])
		go handlePacket(conn, payload, src, dst)
	}
}
// handlePacket forwards one received payload to dst through conn.
//
// The whole of buf is written, so the caller must pass a slice that holds
// only the bytes to send. src is used for logging only. A failed write is
// logged and otherwise ignored.
func handlePacket(conn *net.UDPConn, buf []byte, src, dst *net.UDPAddr) {
	size := len(buf)
	klog.V(2).Infof("Received %d bytes from %s", size, src)

	if _, writeErr := conn.WriteToUDP(buf, dst); writeErr != nil {
		klog.Errorf("Failed to write to UDP: %v", writeErr)
		return
	}

	klog.V(2).Infof("Sent %d bytes to %s", size, dst)
}