diff --git a/conf/clients.csv b/conf/clients.csv
deleted file mode 100644
index ef7cccd..0000000
--- a/conf/clients.csv
+++ /dev/null
@@ -1,2 +0,0 @@
-12,ao0yd0jx6ty0ht69,,true,,,0,false,0,0,0,1
-11,mxg22qa06dc137of,,true,,,0,false,0,0,0,1
diff --git a/conf/clients.json b/conf/clients.json
new file mode 100644
index 0000000..e69de29
diff --git a/conf/hosts.csv b/conf/hosts.csv
deleted file mode 100644
index b48e1c9..0000000
--- a/conf/hosts.csv
+++ /dev/null
@@ -1 +0,0 @@
-a.o.com,123.206.77.88:8080,11,,,,/,1,0,0,all
diff --git a/conf/hosts.json b/conf/hosts.json
new file mode 100644
index 0000000..e69de29
diff --git a/conf/tasks.csv b/conf/tasks.csv
deleted file mode 100644
index b2cc28f..0000000
--- a/conf/tasks.csv
+++ /dev/null
@@ -1 +0,0 @@
-9999,tcp,,1,3,11,,0,0,,0.0.0.0
diff --git a/conf/tasks.json b/conf/tasks.json
new file mode 100644
index 0000000..e69de29
diff --git a/lib/file/file.go b/lib/file/file.go
index d540872..cd47db4 100644
--- a/lib/file/file.go
+++ b/lib/file/file.go
@@ -1,18 +1,16 @@
 package file

 import (
-	"encoding/csv"
+	"encoding/json"
 	"errors"
 	"fmt"
 	"github.com/cnlh/nps/lib/common"
 	"github.com/cnlh/nps/lib/crypt"
 	"github.com/cnlh/nps/lib/rate"
-	"github.com/cnlh/nps/vender/github.com/astaxie/beego/logs"
 	"net/http"
 	"os"
 	"path/filepath"
 	"regexp"
-	"strconv"
 	"strings"
 	"sync"
 	"sync/atomic"
@@ -20,7 +18,10 @@ import (

 func NewCsv(runPath string) *Csv {
 	return &Csv{
-		RunPath: runPath,
+		RunPath:        runPath,
+		TaskFilePath:   filepath.Join(runPath, "conf", "tasks.json"),
+		HostFilePath:   filepath.Join(runPath, "conf", "hosts.json"),
+		ClientFilePath: filepath.Join(runPath, "conf", "clients.json"),
 	}
 }

@@ -33,96 +34,62 @@ type Csv struct {
 	ClientIncreaseId int32  //客户端id
 	TaskIncreaseId   int32  //任务自增ID
 	HostIncreaseId   int32  //host increased id
-}
-
-func (s *Csv) StoreTasksToCsv() {
-	// 创建文件
-	csvFile, err := os.Create(filepath.Join(s.RunPath, "conf", "tasks.csv"))
-	if err != nil {
-		logs.Error(err.Error())
-	}
-	defer csvFile.Close()
-	writer := csv.NewWriter(csvFile)
-	s.Tasks.Range(func(key, value interface{}) bool {
-		task := value.(*Tunnel)
-		if task.NoStore {
-			return true
-		}
-		record := []string{
-			strconv.Itoa(task.Port),
-			task.Mode,
-			task.Target.TargetStr,
-			common.GetStrByBool(task.Status),
-			strconv.Itoa(task.Id),
-			strconv.Itoa(task.Client.Id),
-			task.Remark,
-			strconv.Itoa(int(task.Flow.ExportFlow)),
-			strconv.Itoa(int(task.Flow.InletFlow)),
-			task.Password,
-			task.ServerIp,
-		}
-		err := writer.Write(record)
-		if err != nil {
-			logs.Error(err.Error())
-		}
-		return true
-	})
-	writer.Flush()
-}
-
-func (s *Csv) openFile(path string) ([][]string, error) {
-	// 打开文件
-	file, err := os.Open(path)
-	if err != nil {
-		panic(err)
-	}
-	defer file.Close()
-
-	// 获取csv的reader
-	reader := csv.NewReader(file)
-
-	// 设置FieldsPerRecord为-1
-	reader.FieldsPerRecord = -1
-
-	// 读取文件中所有行保存到slice中
-	return reader.ReadAll()
+	TaskFilePath     string
+	HostFilePath     string
+	ClientFilePath   string
 }

 func (s *Csv) LoadTaskFromCsv() {
-	path := filepath.Join(s.RunPath, "conf", "tasks.csv")
-	records, err := s.openFile(path)
-	if err != nil {
-		logs.Error("Profile Opening Error:", path)
-		os.Exit(0)
-	}
-	// 将每一行数据保存到内存slice中
-	for _, item := range records {
-		post := &Tunnel{
-			Port:     common.GetIntNoErrByStr(item[0]),
-			Mode:     item[1],
-			Status:   common.GetBoolByStr(item[3]),
-			Id:       common.GetIntNoErrByStr(item[4]),
-			Remark:   item[6],
-			Password: item[9],
+	loadSyncMapFromFile(s.TaskFilePath, func(v string) {
+		var err error
+		post := new(Tunnel)
+		if json.Unmarshal([]byte(v), &post) != nil {
+			return
 		}
-		post.Target = new(Target)
-		post.Target.TargetStr = item[2]
-		post.Flow = new(Flow)
-		post.Flow.ExportFlow = int64(common.GetIntNoErrByStr(item[7]))
-		post.Flow.InletFlow = int64(common.GetIntNoErrByStr(item[8]))
-		if post.Client, err = s.GetClient(common.GetIntNoErrByStr(item[5])); err != nil {
-			continue
-		}
-		if len(item) > 10 {
-			post.ServerIp = item[10]
-		} else {
-			post.ServerIp = "0.0.0.0"
+		if post.Client, err = s.GetClient(post.Client.Id); err != nil {
+			return
 		}
 		s.Tasks.Store(post.Id, post)
 		if post.Id > int(s.TaskIncreaseId) {
-			s.TaskIncreaseId = int32(s.TaskIncreaseId)
+			s.TaskIncreaseId = int32(post.Id)
 		}
-	}
+	})
+}
+
+func (s *Csv) LoadClientFromCsv() {
+	loadSyncMapFromFile(s.ClientFilePath, func(v string) {
+		post := new(Client)
+		if json.Unmarshal([]byte(v), &post) != nil {
+			return
+		}
+		if post.RateLimit > 0 {
+			post.Rate = rate.NewRate(int64(post.RateLimit * 1024))
+		} else {
+			post.Rate = rate.NewRate(int64(2 << 23))
+		}
+		post.Rate.Start()
+		s.Clients.Store(post.Id, post)
+		if post.Id > int(s.ClientIncreaseId) {
+			s.ClientIncreaseId = int32(post.Id)
+		}
+	})
+}
+
+func (s *Csv) LoadHostFromCsv() {
+	loadSyncMapFromFile(s.HostFilePath, func(v string) {
+		var err error
+		post := new(Host)
+		if json.Unmarshal([]byte(v), &post) != nil {
+			return
+		}
+		if post.Client, err = s.GetClient(post.Client.Id); err != nil {
+			return
+		}
+		s.Hosts.Store(post.Id, post)
+		if post.Id > int(s.HostIncreaseId) {
+			s.HostIncreaseId = int32(post.Id)
+		}
+	})
 }

 func (s *Csv) GetIdByVerifyKey(vKey string, addr string) (id int, err error) {
@@ -195,135 +162,15 @@ func (s *Csv) GetTask(id int) (t *Tunnel, err error) {
 }

 func (s *Csv) StoreHostToCsv() {
-	// 创建文件
-	csvFile, err := os.Create(filepath.Join(s.RunPath, "conf", "hosts.csv"))
-	if err != nil {
-		panic(err)
-	}
-	defer csvFile.Close()
-	// 获取csv的Writer
-	writer := csv.NewWriter(csvFile)
-	// 将map中的Post转换成slice,因为csv的Write需要slice参数
-	// 并写入csv文件
-	s.Hosts.Range(func(key, value interface{}) bool {
-		host := value.(*Host)
-		if host.NoStore {
-			return true
-		}
-		record := []string{
-			host.Host,
-			host.Target.TargetStr,
-			strconv.Itoa(host.Client.Id),
-			host.HeaderChange,
-			host.HostChange,
-			host.Remark,
-			host.Location,
-			strconv.Itoa(host.Id),
-			strconv.Itoa(int(host.Flow.ExportFlow)),
-			strconv.Itoa(int(host.Flow.InletFlow)),
-			host.Scheme,
-		}
-		err1 := writer.Write(record)
-		if err1 != nil {
-			panic(err1)
-		}
-		return true
-	})
-
-	// 确保所有内存数据刷到csv文件
-	writer.Flush()
+	storeSyncMapToFile(s.Hosts, s.HostFilePath)
 }

-func (s *Csv) LoadClientFromCsv() {
-	path := filepath.Join(s.RunPath, "conf", "clients.csv")
-	records, err := s.openFile(path)
-	if err != nil {
-		logs.Error("Profile Opening Error:", path)
-		os.Exit(0)
-	}
-	// 将每一行数据保存到内存slice中
-	for _, item := range records {
-		post := &Client{
-			Id:        common.GetIntNoErrByStr(item[0]),
-			VerifyKey: item[1],
-			Remark:    item[2],
-			Status:    common.GetBoolByStr(item[3]),
-			RateLimit: common.GetIntNoErrByStr(item[8]),
-			Cnf: &Config{
-				U:        item[4],
-				P:        item[5],
-				Crypt:    common.GetBoolByStr(item[6]),
-				Compress: common.GetBoolByStr(item[7]),
-			},
-			MaxConn: common.GetIntNoErrByStr(item[10]),
-		}
-		if post.Id > int(s.ClientIncreaseId) {
-			s.ClientIncreaseId = int32(post.Id)
-		}
-		if post.RateLimit > 0 {
-			post.Rate = rate.NewRate(int64(post.RateLimit * 1024))
-			post.Rate.Start()
-		} else {
-			post.Rate = rate.NewRate(int64(2 << 23))
-			post.Rate.Start()
-		}
-		post.Flow = new(Flow)
-		post.Flow.FlowLimit = int64(common.GetIntNoErrByStr(item[9]))
-		if len(item) >= 12 {
-			post.ConfigConnAllow = common.GetBoolByStr(item[11])
-		} else {
-			post.ConfigConnAllow = true
-		}
-		if len(item) >= 13 {
-			post.WebUserName = item[12]
-		} else {
-			post.WebUserName = ""
-		}
-		if len(item) >= 14 {
-			post.WebPassword = item[13]
-		} else {
-			post.WebPassword = ""
-		}
-		s.Clients.Store(post.Id, post)
-	}
+func (s *Csv) StoreTasksToCsv() {
+	storeSyncMapToFile(s.Tasks, s.TaskFilePath)
 }

-func (s *Csv) LoadHostFromCsv() {
-	path := filepath.Join(s.RunPath, "conf", "hosts.csv")
-	records, err := s.openFile(path)
-	if err != nil {
-		logs.Error("Profile Opening Error:", path)
-		os.Exit(0)
-	}
-	// 将每一行数据保存到内存slice中
-	for _, item := range records {
-		post := &Host{
-			Host:         item[0],
-			HeaderChange: item[3],
-			HostChange:   item[4],
-			Remark:       item[5],
-			Location:     item[6],
-			Id:           common.GetIntNoErrByStr(item[7]),
-		}
-		if post.Client, err = s.GetClient(common.GetIntNoErrByStr(item[2])); err != nil {
-			continue
-		}
-		post.Target = new(Target)
-		post.Target.TargetStr = item[1]
-		post.Flow = new(Flow)
-		post.Flow.ExportFlow = int64(common.GetIntNoErrByStr(item[8]))
-		post.Flow.InletFlow = int64(common.GetIntNoErrByStr(item[9]))
-		if len(item) > 10 {
-			post.Scheme = item[10]
-		} else {
-			post.Scheme = "all"
-		}
-		s.Hosts.Store(post.Id, post)
-		if post.Id > int(s.HostIncreaseId) {
-			s.HostIncreaseId = int32(post.Id)
-		}
-		//store host to hostMap if the host url is none
-	}
+func (s *Csv) StoreClientsToCsv() {
+	storeSyncMapToFile(s.Clients, s.ClientFilePath)
 }

 func (s *Csv) DelHost(id int) error {
@@ -439,6 +286,7 @@ func (s *Csv) VerifyVkey(vkey string, id int) (res bool) {
 	})
 	return res
 }
+
 func (s *Csv) VerifyUserName(username string, id int) (res bool) {
 	res = true
 	s.Clients.Range(func(key, value interface{}) bool {
@@ -452,18 +300,6 @@ func (s *Csv) VerifyUserName(username string, id int) (res bool) {
 	return res
 }

-func (s *Csv) GetClientId() int32 {
-	return atomic.AddInt32(&s.ClientIncreaseId, 1)
-}
-
-func (s *Csv) GetTaskId() int32 {
-	return atomic.AddInt32(&s.TaskIncreaseId, 1)
-}
-
-func (s *Csv) GetHostId() int32 {
-	return atomic.AddInt32(&s.HostIncreaseId, 1)
-}
-
 func (s *Csv) UpdateClient(t *Client) error {
 	s.Clients.Store(t.Id, t)
 	if t.RateLimit == 0 {
@@ -516,6 +352,7 @@ func (s *Csv) GetClient(id int) (c *Client, err error) {
 	err = errors.New("未找到客户端")
 	return
 }
+
 func (s *Csv) GetClientIdByVkey(vkey string) (id int, err error) {
 	var exist bool
 	s.Clients.Range(func(key, value interface{}) bool {
@@ -585,40 +422,70 @@ func (s *Csv) GetInfoByHost(host string, r *http.Request) (h *Host, err error) {
 	return
 }

-func (s *Csv) StoreClientsToCsv() {
-	// 创建文件
-	csvFile, err := os.Create(filepath.Join(s.RunPath, "conf", "clients.csv"))
+func (s *Csv) GetClientId() int32 {
+	return atomic.AddInt32(&s.ClientIncreaseId, 1)
+}
+
+func (s *Csv) GetTaskId() int32 {
+	return atomic.AddInt32(&s.TaskIncreaseId, 1)
+}
+
+func (s *Csv) GetHostId() int32 {
+	return atomic.AddInt32(&s.HostIncreaseId, 1)
+}
+
+func loadSyncMapFromFile(filePath string, f func(value string)) {
+	b, err := common.ReadAllFromFile(filePath)
 	if err != nil {
-		logs.Error(err.Error())
+		panic(err)
 	}
-	defer csvFile.Close()
-	writer := csv.NewWriter(csvFile)
-	s.Clients.Range(func(key, value interface{}) bool {
-		client := value.(*Client)
-		if client.NoStore {
+	for _, v := range strings.Split(string(b), "\n"+common.CONN_DATA_SEQ) {
+		f(v)
+	}
+}
+
+func storeSyncMapToFile(m sync.Map, filePath string) {
+	file, err := os.Create(filePath)
+	if err != nil {
+		panic(err)
+	}
+	defer file.Close()
+	m.Range(func(key, value interface{}) bool {
+		var b []byte
+		var err error
+		switch value.(type) {
+		case *Tunnel:
+			obj := value.(*Tunnel)
+			if obj.NoStore {
+				return true
+			}
+			b, err = json.Marshal(obj)
+		case *Host:
+			obj := value.(*Host)
+			if obj.NoStore {
+				return true
+			}
+			b, err = json.Marshal(obj)
+		case *Client:
+			obj := value.(*Client)
+			if obj.NoStore {
+				return true
+			}
+			b, err = json.Marshal(obj)
+		default:
 			return true
 		}
-		record := []string{
-			strconv.Itoa(client.Id),
-			client.VerifyKey,
-			client.Remark,
-			strconv.FormatBool(client.Status),
-			client.Cnf.U,
-			client.Cnf.P,
-			common.GetStrByBool(client.Cnf.Crypt),
-			strconv.FormatBool(client.Cnf.Compress),
-			strconv.Itoa(client.RateLimit),
-			strconv.Itoa(int(client.Flow.FlowLimit)),
-			strconv.Itoa(int(client.MaxConn)),
-			common.GetStrByBool(client.ConfigConnAllow),
-			client.WebUserName,
-			client.WebPassword,
-		}
-		err := writer.Write(record)
 		if err != nil {
-			logs.Error(err.Error())
+			return true
+		}
+		_, err = file.Write(b)
+		if err != nil {
+			panic(err)
+		}
+		_, err = file.Write([]byte("\n" + common.CONN_DATA_SEQ))
+		if err != nil {
+			panic(err)
 		}
 		return true
 	})
-	writer.Flush()
 }
diff --git a/lib/file/obj.go b/lib/file/obj.go
index 9a819a0..80f6b9b 100644
--- a/lib/file/obj.go
+++ b/lib/file/obj.go
@@ -183,6 +183,6 @@ type Host struct {
 	Flow         *Flow
 	Client       *Client
 	Target       *Target //目标
-	Health
+	Health       `json:"-"`
 	sync.RWMutex
 }
diff --git a/lib/mux/mux.go b/lib/mux/mux.go
index fbbebf2..315bc68 100644
--- a/lib/mux/mux.go
+++ b/lib/mux/mux.go
@@ -65,7 +65,7 @@ func (s *Mux) NewConn() (*conn, error) {
 		return nil, err
 	}
 	//set a timer timeout 30 second
-	timer := time.NewTimer(time.Second * 30)
+	timer := time.NewTimer(time.Minute * 2)
 	defer timer.Stop()
 	select {
 	case <-conn.connStatusOkCh:
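
Note on the new on-disk format, for illustration only: storeSyncMapToFile writes each entry as one JSON object followed by a "\n" + common.CONN_DATA_SEQ separator, and loadSyncMapFromFile splits on that same separator and silently skips any chunk that fails to unmarshal. The self-contained sketch below mimics that round trip under stated assumptions: the actual value of common.CONN_DATA_SEQ is not visible in this diff, so a placeholder constant stands in for it, and a stub `record` type stands in for Tunnel/Host/Client.

```go
// Minimal sketch of the JSON-per-record format introduced in this change.
// Assumptions: connDataSeq is a placeholder for common.CONN_DATA_SEQ (real
// value not shown in the diff); record is a stand-in for Tunnel/Host/Client.
package main

import (
	"encoding/json"
	"fmt"
	"strings"
)

const connDataSeq = "#+#" // hypothetical separator value

type record struct {
	Id     int
	Remark string
}

// encode mirrors storeSyncMapToFile: marshal each record, then append the separator.
func encode(records []record) string {
	var sb strings.Builder
	for _, r := range records {
		b, err := json.Marshal(r)
		if err != nil {
			continue
		}
		sb.Write(b)
		sb.WriteString("\n" + connDataSeq)
	}
	return sb.String()
}

// decode mirrors loadSyncMapFromFile: split on the separator and unmarshal each chunk.
func decode(data string) []record {
	var out []record
	for _, v := range strings.Split(data, "\n"+connDataSeq) {
		var r record
		if json.Unmarshal([]byte(v), &r) != nil {
			continue // empty trailing chunk or malformed record is skipped
		}
		out = append(out, r)
	}
	return out
}

func main() {
	raw := encode([]record{{Id: 1, Remark: "web"}, {Id: 2, Remark: "ssh"}})
	fmt.Printf("on disk:\n%s\n", raw)
	fmt.Printf("loaded: %+v\n", decode(raw))
}
```

Because the separator is appended after every record, the final split chunk is empty and fails to unmarshal, so it is dropped, which matches the early `return` in the load callbacks above.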