添加多种模式

This commit is contained in:
刘河 2018-11-29 19:55:24 +08:00
parent 2463116b37
commit 3ea895feb5
8 changed files with 831 additions and 215 deletions

View File

@ -1,3 +1,4 @@
<<<<<<< Updated upstream
# easyProxy
轻量级、较高性能http代理服务器主要应用与内网穿透。支持多站点配置、客户端与服务端连接中断自动重连多路传输大大的提高请求处理速度go语言编写无第三方依赖经过测试内存占用小普通场景下仅占用10m内存。
@ -135,12 +136,34 @@ server {
如需开启请加配置文件Replace值设置为1
>注意:开启可能导致不应该被替换的内容被替换,请谨慎开启
=======
# rproxy
简单的反向代理用于内网穿透
**特别注意,此工具只适合小文件类的访问测试,用来做做数据调试。当初也只是用于微信公众号开发,所以定位也是如此**
## 前言
最近周末闲来无事想起了做下微信公共号的开发但微信限制只能80端口的自己用的城中村的那种宽带共用一个公网没办法自己用路由做端口映射。自己的服务器在腾讯云上每次都要编译完后用ftp上传再进行调试非常的浪费时间。 一时间又不知道上哪找一个符合我的这种要求的工具,就索性自己构思了下,整个工作流程大致为:
## 工作原理
> 外部请求自己服务器上的HTTP服务端 -> 将数据传递给Socket服务器 -> Socket服务器将数据发送至已连接的Socket客户端 -> Socket客户端收到数据 -> 使用http请求本地http服务端 -> 本地http服务端处理相关后返回 -> Socket客户端将返回的数据发送至Socket服务端 -> Socket服务端解析出数据后原路返回至外部请求的HTTP
## 使用方法
> 1、go get github.com/ying32/rproxy
> 2、go build
> 3、服务端运行runsvr.bat或者runsvr.sh
> 4、客户端运行runcli.bat或者runcli.sh
## 命令行说明
> --tcpport Socket连接或者监听的端口
> --httpport 当mode为server时为服务端监听端口当为mode为client时为转发至本地客户端的端口
> --mode 启动模式可选为client、server默认为client
> --svraddr 当mode为client时有效为连接服务器的地址不需要填写端口
> --vkey 客户端与服务端建立连接时校验的加密key简单的。
>>>>>>> Stashed changes
## 操作系统支持
支持Windows、Linux、MacOSX等无第三方依赖库。
## 二级域名泛解析配置详细教程
[详细教程](https://github.com/cnlh/easyProxy/wiki/%E4%BD%BF%E7%94%A8%E6%95%99%E7%A8%8B)
支持Windows、Linux、MacOSX等无第三方依赖库。
## 二进制下载
https://github.com/ying32/rproxy/releases/tag/v0.4

156
client.go
View File

@ -1,20 +1,15 @@
package main
import (
"encoding/binary"
"errors"
"fmt"
"io"
"log"
"net"
"net/http"
"strings"
"sync"
"time"
)
var (
disabledRedirect = errors.New("disabled redirect.")
)
type TRPClient struct {
svrAddr string
tcpNum int
@ -28,56 +23,58 @@ func NewRPClient(svraddr string, tcpNum int) *TRPClient {
return c
}
func (c *TRPClient) Start() error {
for i := 0; i < c.tcpNum; i++ {
go c.newConn()
func (s *TRPClient) Start() error {
for i := 0; i < s.tcpNum; i++ {
go s.newConn()
}
for {
time.Sleep(5 * time.Second)
time.Sleep(time.Second * 5)
}
return nil
}
func (c *TRPClient) newConn() error {
c.Lock()
conn, err := net.Dial("tcp", c.svrAddr)
//新建
func (s *TRPClient) newConn() error {
s.Lock()
conn, err := net.Dial("tcp", s.svrAddr)
if err != nil {
log.Println("连接服务端失败,五秒后将重连")
time.Sleep(time.Second * 5)
c.Unlock()
c.newConn()
s.Unlock()
go s.newConn()
return err
}
c.Unlock()
conn.(*net.TCPConn).SetKeepAlive(true)
conn.(*net.TCPConn).SetKeepAlivePeriod(time.Duration(2 * time.Second))
return c.process(conn)
s.Unlock()
return s.process(NewConn(conn))
}
func (c *TRPClient) werror(conn net.Conn) {
conn.Write([]byte("msg0"))
}
func (c *TRPClient) process(conn net.Conn) error {
if _, err := conn.Write(getverifyval()); err != nil {
func (s *TRPClient) process(c *Conn) error {
c.SetAlive()
if _, err := c.Write(getverifyval()); err != nil {
return err
}
val := make([]byte, 4)
c.wMain()
for {
_, err := conn.Read(val)
flags, err := c.ReadFlag()
if err != nil {
log.Println("服务端断开,五秒后将重连", err)
time.Sleep(5 * time.Second)
go c.newConn()
return err
go s.newConn()
break
}
flags := string(val)
switch flags {
case "vkey":
case VERIFY_EER:
log.Fatal("vkey不正确,请检查配置文件")
case "sign":
c.deal(conn)
case "msg0":
case RES_SIGN: //代理请求模式
if err := s.dealHttp(c); err != nil {
log.Println(err)
return err
}
case WORK_CHAN: //隧道模式每次开启10个加快连接速度
for i := 0; i < 10; i++ {
go s.dealChan()
}
case RES_MSG:
log.Println("服务端返回错误。")
default:
log.Println("无法解析该错误。")
@ -85,69 +82,64 @@ func (c *TRPClient) process(conn net.Conn) error {
}
return nil
}
func (c *TRPClient) deal(conn net.Conn) error {
val := make([]byte, 4)
_, err := conn.Read(val)
nlen := binary.LittleEndian.Uint32(val)
log.Println("收到服务端数据,长度:", nlen)
if nlen <= 0 {
log.Println("数据长度错误。")
c.werror(conn)
return errors.New("数据长度错误")
//隧道模式处理
func (s *TRPClient) dealChan() error {
//创建一个tcp连接
conn, err := net.Dial("tcp", s.svrAddr)
//验证
if _, err := conn.Write(getverifyval()); err != nil {
return err
}
raw := make([]byte, nlen)
n, err := conn.Read(raw)
//默认长连接保持
c := NewConn(conn)
c.SetAlive()
//写标志
c.wChan()
//获取连接的host
host, err := c.GetHostFromConn()
if err != nil {
return err
}
if n != int(nlen) {
log.Printf("读取服务端数据长度错误,已经读取%dbyte总长度%d字节\n", n, nlen)
c.werror(conn)
return errors.New("读取服务端数据长度错误")
//与目标建立连接
server, err := net.Dial("tcp", host)
if err != nil {
return err
}
//创建成功后io.copy
go io.Copy(server, c)
io.Copy(c, server)
return nil
}
//http模式处理
func (s *TRPClient) dealHttp(c *Conn) error {
nlen, err := c.GetLen()
if err != nil {
c.wError()
return err
}
raw, err := c.ReadLen(int(nlen))
if err != nil {
c.wError()
return err
}
req, err := DecodeRequest(raw)
if err != nil {
log.Println("DecodeRequest错误", err)
c.werror(conn)
c.wError()
return err
}
rawQuery := ""
if req.URL.RawQuery != "" {
rawQuery = "?" + req.URL.RawQuery
}
log.Println(req.URL.Path + rawQuery)
client := new(http.Client)
client.CheckRedirect = func(req *http.Request, via []*http.Request) error {
return disabledRedirect
}
resp, err := client.Do(req)
disRedirect := err != nil && strings.Contains(err.Error(), disabledRedirect.Error())
if err != nil && !disRedirect {
log.Println("请求本地客户端错误:", err)
c.werror(conn)
return err
}
if !disRedirect {
defer resp.Body.Close()
} else {
resp.Body = nil
resp.ContentLength = 0
}
respBytes, err := EncodeResponse(resp)
respBytes, err := GetEncodeResponse(req)
if err != nil {
log.Println("EncodeResponse错误", err)
c.werror(conn)
c.wError()
return err
}
n, err = conn.Write(respBytes)
n, err := c.Write(respBytes)
if err != nil {
log.Println("发送数据错误,错误:", err)
return err
}
if n != len(respBytes) {
log.Printf("发送数据长度错误,已经发送:%dbyte总字节长%dbyte\n", n, len(respBytes))
} else {
log.Printf("本次请求成功完成,共发送:%dbyte\n", n)
return errors.New(fmt.Sprintf("发送数据长度错误,已经发送:%dbyte总字节长%dbyte\n", n, len(respBytes)))
}
return nil
}

118
conn.go Normal file
View File

@ -0,0 +1,118 @@
package main
import (
"bytes"
"encoding/binary"
"errors"
"fmt"
"io"
"net"
"time"
)
type Conn struct {
conn net.Conn
}
func NewConn(conn net.Conn) *Conn {
c := new(Conn)
c.conn = conn
return c
}
//读取指定内容长度
func (s *Conn) ReadLen(len int) ([]byte, error) {
raw := make([]byte, 0)
buff := make([]byte, 1024)
c := 0
for {
clen, err := s.conn.Read(buff)
if err != nil && err != io.EOF {
return raw, err
}
raw = append(raw, buff[:clen]...)
if c += clen; c >= len {
break
}
}
if c != len {
return raw, errors.New(fmt.Sprintf("已读取长度错误,已读取%dbyte需要读取%dbyte。", c, len))
}
return raw, nil
}
//获取长度
func (s *Conn) GetLen() (int, error) {
val := make([]byte, 4)
_, err := s.conn.Read(val)
if err != nil {
return 0, err
}
nlen := binary.LittleEndian.Uint32(val)
if nlen <= 0 {
return 0, errors.New("数据长度错误")
}
return int(nlen), nil
}
//读取flag
func (s *Conn) ReadFlag() (string, error) {
val := make([]byte, 4)
_, err := s.conn.Read(val)
if err != nil {
return "", err
}
return string(val), err
}
//读取host
func (s *Conn) GetHostFromConn() (string, error) {
len, err := s.GetLen()
if err != nil {
return "", err
}
hostByte := make([]byte, len)
_, err = s.conn.Read(hostByte)
if err != nil {
return "", err
}
return string(hostByte), nil
}
//获取host
func (s *Conn) WriteHost(host string) (int, error) {
raw := bytes.NewBuffer([]byte{})
binary.Write(raw, binary.LittleEndian, int32(len([]byte(host))))
binary.Write(raw, binary.LittleEndian, []byte(host))
return s.Write(raw.Bytes())
}
//设置连接为长连接
func (s *Conn) SetAlive() {
conn := s.conn.(*net.TCPConn)
conn.SetReadDeadline(time.Time{})
conn.SetKeepAlive(true)
conn.SetKeepAlivePeriod(time.Duration(2 * time.Second))
}
func (s *Conn) Close() error {
return s.conn.Close()
}
func (s *Conn) Write(b []byte) (int, error) {
return s.conn.Write(b)
}
func (s *Conn) Read(b []byte) (int, error) {
return s.conn.Read(b)
}
func (s *Conn) wError() {
s.conn.Write([]byte(RES_MSG))
}
func (s *Conn) wMain() {
s.conn.Write([]byte(WORK_MAIN))
}
func (s *Conn) wChan() {
s.conn.Write([]byte(WORK_CHAN))
}

36
main.go
View File

@ -7,13 +7,14 @@ import (
)
var (
configPath = flag.String("config", "config.json", "配置文件路径")
tcpPort = flag.Int("tcpport", 8284, "Socket连接或者监听的端口")
httpPort = flag.Int("httpport", 8024, "当mode为server时为服务端监听端口当为mode为client时为转发至本地客户端的端口")
rpMode = flag.String("mode", "client", "启动模式可选为client、server")
verifyKey = flag.String("vkey", "", "验证密钥")
config Config
err error
configPath = flag.String("config", "config.json", "配置文件路径")
tcpPort = flag.Int("tcpport", 8284, "Socket连接或者监听的端口")
httpPort = flag.Int("httpport", 8024, "当mode为server时为服务端监听端口当为mode为client时为转发至本地客户端的端口")
rpMode = flag.String("mode", "client", "启动模式可选为client、server")
tunnelTarget = flag.String("target", "10.1.50.203:80", "tunnel模式远程目标")
verifyKey = flag.String("vkey", "", "验证密钥")
config Config
err error
)
func main() {
@ -29,7 +30,7 @@ func main() {
log.Println("客户端启动,连接:", config.Server.Ip, " 端口:", config.Server.Tcp)
cli := NewRPClient(fmt.Sprintf("%s:%d", config.Server.Ip, config.Server.Tcp), config.Server.Num)
cli.Start()
} else if *rpMode == "server" {
} else {
if *verifyKey == "" {
log.Fatalln("必须输入一个验证的key")
}
@ -39,11 +40,20 @@ func main() {
if *httpPort <= 0 || *httpPort >= 65536 {
log.Fatalln("请输入正确的http端口。")
}
log.Println("服务端启动监听tcp服务端端口", *tcpPort, " http服务端端口", *httpPort)
svr := NewRPServer(*tcpPort, *httpPort)
if err := svr.Start(); err != nil {
log.Fatalln(err)
log.Println("服务端启动监听tcp服务端端口", *tcpPort, " 外部服务端端口:", *httpPort)
if *rpMode == "httpServer" {
svr := NewHttpModeServer(*tcpPort, *httpPort)
if err := svr.Start(); err != nil {
log.Fatalln(err)
}
} else if *rpMode == "tunnelServer" {
svr := NewTunnelModeServer(*tcpPort, *httpPort, *tunnelTarget)
if err := svr.Start(); err != nil {
log.Fatalln(err)
}
} else if *rpMode == "sock5Server" {
svr := NewSock5ModeServer(*tcpPort, *httpPort)
svr.Start()
}
defer svr.Close()
}
}

220
server.go
View File

@ -1,8 +1,6 @@
package main
import (
"bytes"
"encoding/binary"
"errors"
"fmt"
"io"
@ -10,115 +8,69 @@ import (
"log"
"net"
"net/http"
"sync"
"time"
)
type TRPServer struct {
tcpPort int
const (
VERIFY_EER = "vkey"
WORK_MAIN = "main"
WORK_CHAN = "chan"
RES_SIGN = "sign"
RES_MSG = "msg0"
)
type HttpModeServer struct {
Tunnel
httpPort int
listener *net.TCPListener
connList chan net.Conn
sync.RWMutex
}
func NewRPServer(tcpPort, httpPort int) *TRPServer {
s := new(TRPServer)
s.tcpPort = tcpPort
func NewHttpModeServer(tcpPort, httpPort int) *HttpModeServer {
s := new(HttpModeServer)
s.tunnelPort = tcpPort
s.httpPort = httpPort
s.connList = make(chan net.Conn, 1000)
s.signalList = make(chan *Conn, 1000)
return s
}
func (s *TRPServer) Start() error {
var err error
s.listener, err = net.ListenTCP("tcp", &net.TCPAddr{net.ParseIP("0.0.0.0"), s.tcpPort, ""})
//开始
func (s *HttpModeServer) Start() (error) {
err := s.StartTunnel()
if err != nil {
log.Fatalln("开启客户端失败!", err)
return err
}
go s.httpserver()
return s.tcpserver()
s.startHttpServer()
return nil
}
func (s *TRPServer) Close() error {
if s.listener != nil {
err := s.listener.Close()
s.listener = nil
return err
}
return errors.New("TCP实例未创建")
}
func (s *TRPServer) tcpserver() error {
var err error
for {
conn, err := s.listener.AcceptTCP()
if err != nil {
log.Println(err)
continue
}
go s.cliProcess(conn)
}
return err
}
func badRequest(w http.ResponseWriter) {
http.Error(w, http.StatusText(http.StatusBadRequest), http.StatusBadRequest)
}
func (s *TRPServer) httpserver() {
//开启http端口监听
func (s *HttpModeServer) startHttpServer() {
http.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
retry:
if len(s.connList) == 0 {
badRequest(w)
if len(s.signalList) == 0 {
BadRequest(w)
return
}
conn := <-s.connList
log.Println(r.RequestURI)
err := s.write(r, conn)
conn := <-s.signalList
if err := s.writeRequest(r, conn); err != nil {
log.Println(err)
conn.Close()
goto retry
return
}
err = s.writeResponse(w, conn)
if err != nil {
log.Println(err)
conn.Close()
goto retry
return
}
err = s.read(w, conn)
if err != nil {
log.Println(err)
conn.Close()
goto retry
return
}
s.connList <- conn
conn = nil
s.signalList <- conn
})
log.Fatalln(http.ListenAndServe(fmt.Sprintf(":%d", s.httpPort), nil))
}
func (s *TRPServer) cliProcess(conn *net.TCPConn) error {
conn.SetReadDeadline(time.Now().Add(time.Duration(5) * time.Second))
vval := make([]byte, 20)
_, err := conn.Read(vval)
if err != nil {
log.Println("客户端读超时。客户端地址为::", conn.RemoteAddr())
conn.Close()
return err
}
if bytes.Compare(vval, getverifyval()[:]) != 0 {
log.Println("当前客户端连接校验错误,关闭此客户端:", conn.RemoteAddr())
conn.Write([]byte("vkey"))
conn.Close()
return err
}
conn.SetReadDeadline(time.Time{})
log.Println("连接新的客户端:", conn.RemoteAddr())
conn.SetKeepAlive(true)
conn.SetKeepAlivePeriod(time.Duration(2 * time.Second))
s.connList <- conn
return nil
}
func (s *TRPServer) write(r *http.Request, conn net.Conn) error {
//req转为bytes发送给client端
func (s *HttpModeServer) writeRequest(r *http.Request, conn *Conn) error {
raw, err := EncodeRequest(r)
if err != nil {
return err
@ -133,41 +85,21 @@ func (s *TRPServer) write(r *http.Request, conn net.Conn) error {
return nil
}
func (s *TRPServer) read(w http.ResponseWriter, conn net.Conn) (error) {
val := make([]byte, 4)
_, err := conn.Read(val)
//从client读取出Response
func (s *HttpModeServer) writeResponse(w http.ResponseWriter, c *Conn) error {
flags, err := c.ReadFlag()
if err != nil {
return err
}
flags := string(val)
switch flags {
case "sign":
_, err = conn.Read(val)
case RES_SIGN:
nlen, err := c.GetLen()
if err != nil {
return err
}
nlen := int(binary.LittleEndian.Uint32(val))
if nlen == 0 {
return errors.New("读取客户端长度错误。")
}
log.Println("收到客户端数据,需要读取长度:", nlen)
raw := make([]byte, 0)
buff := make([]byte, 1024)
c := 0
for {
clen, err := conn.Read(buff)
if err != nil && err != io.EOF {
return err
}
raw = append(raw, buff[:clen]...)
c += clen
if c >= nlen {
break
}
}
log.Println("读取完成,长度:", c, "实际raw长度", len(raw))
if c != nlen {
return fmt.Errorf("已读取长度错误,已读取%dbyte需要读取%dbyte。", c, nlen)
raw, err := c.ReadLen(nlen)
if err != nil {
return err
}
resp, err := DecodeResponse(raw)
if err != nil {
@ -184,10 +116,70 @@ func (s *TRPServer) read(w http.ResponseWriter, conn net.Conn) (error) {
}
w.WriteHeader(resp.StatusCode)
w.Write(bodyBytes)
case "msg0":
return nil
case RES_MSG:
BadRequest(w)
return errors.New("客户端请求出错")
default:
log.Println("无法解析此错误", string(val))
BadRequest(w)
return errors.New("无法解析此错误")
}
return nil
}
type TunnelModeServer struct {
Tunnel
httpPort int
tunnelTarget string
}
func NewTunnelModeServer(tcpPort, httpPort int, tunnelTarget string) *TunnelModeServer {
s := new(TunnelModeServer)
s.tunnelPort = tcpPort
s.httpPort = httpPort
s.tunnelTarget = tunnelTarget
s.tunnelList = make(chan *Conn, 1000)
s.signalList = make(chan *Conn, 10)
return s
}
//开始
func (s *TunnelModeServer) Start() (error) {
err := s.StartTunnel()
if err != nil {
log.Fatalln("开启客户端失败!", err)
return err
}
s.startTunnelServer()
return nil
}
//隧道模式server
func (s *TunnelModeServer) startTunnelServer() {
listener, err := net.ListenTCP("tcp", &net.TCPAddr{net.ParseIP("0.0.0.0"), s.httpPort, ""})
if err != nil {
log.Fatalln(err)
}
for {
conn, err := listener.AcceptTCP()
if err != nil {
log.Println(err)
continue
}
go s.process(NewConn(conn))
}
}
//监听连接处理
func (s *TunnelModeServer) process(c *Conn) error {
retry:
if len(s.tunnelList) < 10 { //新建通道
go s.newChan()
}
link := <-s.tunnelList
if _, err := link.WriteHost(s.tunnelTarget); err != nil {
goto retry
}
go io.Copy(link, c)
io.Copy(c, link.conn)
return nil
}

236
sock5.go Normal file
View File

@ -0,0 +1,236 @@
package main
import (
"encoding/binary"
"errors"
"io"
"log"
"net"
"strconv"
)
const (
ipV4 = 1
domainName = 3
ipV6 = 4
connectMethod = 1
bindMethod = 2
associateMethod = 3
// The maximum packet size of any udp Associate packet, based on ethernet's max size,
// minus the IP and UDP headers. IPv4 has a 20 byte header, UDP adds an
// additional 4 bytes. This is a total overhead of 24 bytes. Ethernet's
// max packet size is 1500 bytes, 1500 - 24 = 1476.
maxUDPPacketSize = 1476
)
const (
succeeded uint8 = iota
serverFailure
notAllowed
networkUnreachable
hostUnreachable
connectionRefused
ttlExpired
commandNotSupported
addrTypeNotSupported
)
type Sock5ModeServer struct {
Tunnel
httpPort int
}
func (s *Sock5ModeServer) handleRequest(c net.Conn) {
/*
The SOCKS request is formed as follows:
+----+-----+-------+------+----------+----------+
|VER | CMD | RSV | ATYP | DST.ADDR | DST.PORT |
+----+-----+-------+------+----------+----------+
| 1 | 1 | X'00' | 1 | Variable | 2 |
+----+-----+-------+------+----------+----------+
*/
header := make([]byte, 3)
_, err := io.ReadFull(c, header)
if err != nil {
log.Println("illegal request", err)
c.Close()
return
}
switch header[1] {
case connectMethod:
s.handleConnect(c)
case bindMethod:
s.handleBind(c)
case associateMethod:
s.handleUDP(c)
default:
s.sendReply(c, commandNotSupported)
c.Close()
}
}
func (s *Sock5ModeServer) sendReply(c net.Conn, rep uint8) {
reply := []byte{
5,
rep,
0,
1,
}
localAddr := c.LocalAddr().String()
localHost, localPort, _ := net.SplitHostPort(localAddr)
ipBytes := net.ParseIP(localHost).To4()
nPort, _ := strconv.Atoi(localPort)
reply = append(reply, ipBytes...)
portBytes := make([]byte, 2)
binary.BigEndian.PutUint16(portBytes, uint16(nPort))
reply = append(reply, portBytes...)
c.Write(reply)
}
func (s *Sock5ModeServer) doConnect(c net.Conn, command uint8) (proxyConn *Conn, err error) {
addrType := make([]byte, 1)
c.Read(addrType)
var host string
switch addrType[0] {
case ipV4:
ipv4 := make(net.IP, net.IPv4len)
c.Read(ipv4)
host = ipv4.String()
case ipV6:
ipv6 := make(net.IP, net.IPv6len)
c.Read(ipv6)
host = ipv6.String()
case domainName:
var domainLen uint8
binary.Read(c, binary.BigEndian, &domainLen)
domain := make([]byte, domainLen)
c.Read(domain)
host = string(domain)
default:
s.sendReply(c, addrTypeNotSupported)
err = errors.New("Address type not supported")
return nil, err
}
var port uint16
binary.Read(c, binary.BigEndian, &port)
// connect to host
addr := net.JoinHostPort(host, strconv.Itoa(int(port)))
//取出一个连接
if len(s.tunnelList) < 10 { //新建通道
go s.newChan()
}
client := <-s.tunnelList
s.sendReply(c, succeeded)
_, err = client.WriteHost(addr)
return client, nil
}
func (s *Sock5ModeServer) handleConnect(c net.Conn) {
proxyConn, err := s.doConnect(c, connectMethod)
if err != nil {
c.Close()
} else {
go io.Copy(c, proxyConn)
go io.Copy(proxyConn, c)
}
}
func (s *Sock5ModeServer) relay(in, out net.Conn) {
if _, err := io.Copy(in, out); err != nil {
log.Println("copy error", err)
}
in.Close() // will trigger an error in the other relay, then call out.Close()
}
// passive mode
func (s *Sock5ModeServer) handleBind(c net.Conn) {
}
func (s *Sock5ModeServer) handleUDP(c net.Conn) {
log.Println("UDP Associate")
/*
+----+------+------+----------+----------+----------+
|RSV | FRAG | ATYP | DST.ADDR | DST.PORT | DATA |
+----+------+------+----------+----------+----------+
| 2 | 1 | 1 | Variable | 2 | Variable |
+----+------+------+----------+----------+----------+
*/
buf := make([]byte, 3)
c.Read(buf)
// relay udp datagram silently, without any notification to the requesting client
if buf[2] != 0 {
// does not support fragmentation, drop it
log.Println("does not support fragmentation, drop")
dummy := make([]byte, maxUDPPacketSize)
c.Read(dummy)
}
proxyConn, err := s.doConnect(c, associateMethod)
if err != nil {
c.Close()
} else {
go io.Copy(c, proxyConn)
go io.Copy(proxyConn, c)
}
}
func (s *Sock5ModeServer) handleNewConn(c net.Conn) {
buf := make([]byte, 2)
if _, err := io.ReadFull(c, buf); err != nil {
log.Println("negotiation err", err)
c.Close()
return
}
if version := buf[0]; version != 5 {
log.Println("only support socks5, request from: ", c.RemoteAddr())
c.Close()
return
}
nMethods := buf[1]
methods := make([]byte, nMethods)
if len, err := c.Read(methods); len != int(nMethods) || err != nil {
log.Println("wrong method")
c.Close()
return
}
// no authentication required for now
buf[1] = 0
// send a METHOD selection message
c.Write(buf)
s.handleRequest(c)
}
func (s *Sock5ModeServer) Start() {
l, err := net.Listen("tcp", ":"+strconv.Itoa(s.httpPort))
if err != nil {
log.Fatal("listen error: ", err)
}
s.StartTunnel()
for {
conn, err := l.Accept()
if err != nil {
log.Fatal("accept error: ", err)
}
go s.handleNewConn(conn)
}
}
func NewSock5ModeServer(tcpPort, httpPort int) *Sock5ModeServer {
s := new(Sock5ModeServer)
s.tunnelPort = tcpPort
s.httpPort = httpPort
s.tunnelList = make(chan *Conn, 1000)
s.signalList = make(chan *Conn, 10)
return s
}

97
tunnel.go Normal file
View File

@ -0,0 +1,97 @@
package main
import (
"bytes"
"errors"
"fmt"
"log"
"net"
"sync"
"time"
)
type Tunnel struct {
tunnelPort int //通信隧道端口
listener *net.TCPListener //server端监听
signalList chan *Conn //通信
tunnelList chan *Conn //隧道
sync.RWMutex
}
func (s *Tunnel) StartTunnel() error {
var err error
s.listener, err = net.ListenTCP("tcp", &net.TCPAddr{net.ParseIP("0.0.0.0"), s.tunnelPort, ""})
if err != nil {
return err
}
go s.tunnelProcess()
return nil
}
//tcp server
func (s *Tunnel) tunnelProcess() error {
var err error
for {
conn, err := s.listener.Accept()
if err != nil {
log.Println(err)
continue
}
go s.cliProcess(NewConn(conn))
}
return err
}
//验证失败返回错误验证flag并且关闭连接
func (s *Tunnel) verifyError(c *Conn) {
c.conn.Write([]byte(VERIFY_EER))
c.conn.Close()
}
func (s *Tunnel) cliProcess(c *Conn) error {
c.conn.(*net.TCPConn).SetReadDeadline(time.Now().Add(time.Duration(5) * time.Second))
vval := make([]byte, 20)
_, err := c.conn.Read(vval)
if err != nil {
log.Println("客户端读超时。客户端地址为::", c.conn.RemoteAddr())
c.conn.Close()
return err
}
if bytes.Compare(vval, getverifyval()[:]) != 0 {
log.Println("当前客户端连接校验错误,关闭此客户端:", c.conn.RemoteAddr())
s.verifyError(c)
return err
}
//做一个判断 添加到对应的channel里面以供使用
flag, err := c.ReadFlag()
if err != nil {
return err
}
return s.typeDeal(flag, c)
}
//tcp连接类型区分
func (s *Tunnel) typeDeal(typeVal string, c *Conn) error {
switch typeVal {
case WORK_MAIN:
s.signalList <- c
case WORK_CHAN:
s.tunnelList <- c
default:
return errors.New("无法识别")
}
c.SetAlive()
return nil
}
//新建隧道
func (s *Tunnel) newChan() {
retry:
connPass := <-s.signalList
_, err := connPass.conn.Write([]byte("chan"))
if err != nil {
fmt.Println(err)
goto retry
}
s.signalList <- connPass
}

148
util.go Normal file
View File

@ -0,0 +1,148 @@
package main
import (
"bufio"
"bytes"
"compress/gzip"
"encoding/binary"
"errors"
"fmt"
"net/http"
"net/http/httputil"
"net/url"
"strconv"
"strings"
)
var (
disabledRedirect = errors.New("disabled redirect.")
)
func BadRequest(w http.ResponseWriter) {
http.Error(w, http.StatusText(http.StatusBadRequest), http.StatusBadRequest)
}
//发送请求并转为bytes
func GetEncodeResponse(req *http.Request) ([]byte, error) {
var respBytes []byte
client := new(http.Client)
client.CheckRedirect = func(req *http.Request, via []*http.Request) error {
return disabledRedirect
}
resp, err := client.Do(req)
disRedirect := err != nil && strings.Contains(err.Error(), disabledRedirect.Error())
if err != nil && !disRedirect {
return respBytes, err
}
if !disRedirect {
defer resp.Body.Close()
} else {
resp.Body = nil
resp.ContentLength = 0
}
respBytes, err = EncodeResponse(resp)
return respBytes, nil
}
// 将request 的处理
func EncodeRequest(r *http.Request) ([]byte, error) {
raw := bytes.NewBuffer([]byte{})
// 写签名
binary.Write(raw, binary.LittleEndian, []byte("sign"))
reqBytes, err := httputil.DumpRequest(r, true)
if err != nil {
return nil, err
}
// 写body数据长度 + 1
binary.Write(raw, binary.LittleEndian, int32(len(reqBytes)+1))
// 判断是否为http或者https的标识1字节
binary.Write(raw, binary.LittleEndian, bool(r.URL.Scheme == "https"))
if err := binary.Write(raw, binary.LittleEndian, reqBytes); err != nil {
return nil, err
}
return raw.Bytes(), nil
}
// 将字节转为request
func DecodeRequest(data []byte) (*http.Request, error) {
if len(data) <= 100 {
return nil, errors.New("待解码的字节长度太小")
}
req, err := http.ReadRequest(bufio.NewReader(bytes.NewReader(data[1:])))
if err != nil {
return nil, err
}
str := strings.Split(req.Host, ":")
req.Host, err = getHost(str[0])
if err != nil {
return nil, err
}
scheme := "http"
if data[0] == 1 {
scheme = "https"
}
req.URL, _ = url.Parse(fmt.Sprintf("%s://%s%s", scheme, req.Host, req.RequestURI))
req.RequestURI = ""
return req, nil
}
//// 将response转为字节
func EncodeResponse(r *http.Response) ([]byte, error) {
raw := bytes.NewBuffer([]byte{})
binary.Write(raw, binary.LittleEndian, []byte(RES_SIGN))
respBytes, err := httputil.DumpResponse(r, true)
if config.Replace == 1 {
respBytes = replaceHost(respBytes)
}
if err != nil {
return nil, err
}
var buf bytes.Buffer
zw := gzip.NewWriter(&buf)
zw.Write(respBytes)
zw.Close()
binary.Write(raw, binary.LittleEndian, int32(len(buf.Bytes())))
if err := binary.Write(raw, binary.LittleEndian, buf.Bytes()); err != nil {
fmt.Println(err)
return nil, err
}
return raw.Bytes(), nil
}
// 将字节转为response
func DecodeResponse(data []byte) (*http.Response, error) {
zr, err := gzip.NewReader(bytes.NewReader(data))
if err != nil {
return nil, err
}
defer zr.Close()
resp, err := http.ReadResponse(bufio.NewReader(zr), nil)
if err != nil {
return nil, err
}
return resp, nil
}
func getHost(str string) (string, error) {
for _, v := range config.SiteList {
if v.Host == str {
return v.Url + ":" + strconv.Itoa(v.Port), nil
}
}
return "", errors.New("没有找到解析的的host!")
}
func replaceHost(resp []byte) []byte {
str := string(resp)
for _, v := range config.SiteList {
str = strings.Replace(str, v.Url+":"+strconv.Itoa(v.Port), v.Host, -1)
str = strings.Replace(str, v.Url, v.Host, -1)
}
return []byte(str)
}