From ecc23e67032fd8fa55b769dd1b28d9ce1da60464 Mon Sep 17 00:00:00 2001 From: YangXin <245051644@qq.com> Date: Mon, 12 Dec 2022 06:46:36 +0000 Subject: [PATCH] Add whitelist to rexec_server Signed-off-by: YangXin <245051644@qq.com> --- qtfs/rexec/client.go | 4 +++- qtfs/rexec/common.go | 29 +++++++++++++++++++++++++++++ qtfs/rexec/server.go | 38 ++++++++++++++++++++++++++++++++++++-- 3 files changed, 68 insertions(+), 3 deletions(-) diff --git a/qtfs/rexec/client.go b/qtfs/rexec/client.go index 13b63f5..dc1af8b 100644 --- a/qtfs/rexec/client.go +++ b/qtfs/rexec/client.go @@ -156,7 +156,6 @@ func main() { retryCnt := 3 // 1. get pid from response - time.Sleep(5 * time.Millisecond) response := &CommandResponse{} retry: err = receiver.Receive(response) @@ -168,6 +167,9 @@ retry: } log.Fatal(err) } + if (response.WhiteList == 0) { + log.Fatalf("%s command in White List of rexec server\n", command.Cmd) + } pid := response.Pid lpid := os.Getpid() log.Printf("create pidFile for %d:%d\n", pid, lpid) diff --git a/qtfs/rexec/common.go b/qtfs/rexec/common.go index 9ce21c4..b59b12b 100644 --- a/qtfs/rexec/common.go +++ b/qtfs/rexec/common.go @@ -8,6 +8,7 @@ import ( "os" "strconv" "strings" + "syscall" "io/ioutil" "encoding/json" @@ -30,10 +31,34 @@ type RemoteCommand struct { Cgroups map[string]string } +func CheckRight(fileName string) error { + var uid int + var gid int + var mode int + var stat syscall.Stat_t + if err := syscall.Stat(fileName, &stat); err != nil { + return fmt.Errorf("Can't get status of %s: %s\n", fileName, err) + } + uid = int(stat.Uid) + gid = int(stat.Gid) + mode = int(stat.Mode) + + if (uid != 0 || gid != 0) { + return fmt.Errorf("Owner of %s must be root\n", fileName) + } + + if (mode & 0777 != 0400) { + return fmt.Errorf("Mode of %s must be 0400\n", fileName) + } + + return nil +} + // CommandResponse is the returned response object from the remote execution type CommandResponse struct { Pid int Status int + WhiteList int } // NetAddr is struct to describe net proto and addr @@ -90,6 +115,10 @@ func parseUnixAddr(inAddr string) (NetAddr, error) { func readAddrFromFile(role string) (string) { fileName := fmt.Sprintf("%s/%s.json", configDir, role) + if err := CheckRight(fileName); err != nil { + fmt.Printf("Check right of %s failed: %s", fileName, err) + return "" + } file, err := ioutil.ReadFile(fileName) if err != nil { fmt.Printf("read %s failed: %s", fileName, err) diff --git a/qtfs/rexec/server.go b/qtfs/rexec/server.go index 4559b79..de3f6cf 100644 --- a/qtfs/rexec/server.go +++ b/qtfs/rexec/server.go @@ -4,6 +4,7 @@ import ( "crypto/tls" "fmt" "io" + "io/ioutil" "log" "net" "os" @@ -17,13 +18,33 @@ import ( const ( role = "server" + whiteList = "whitelist" ) +var WhiteLists map[string] int +func getWhitelist() error { + fileName := fmt.Sprintf("%s/%s", configDir, whiteList) + if err := CheckRight(fileName); err != nil { + log.Fatal(err) + } + file, err := ioutil.ReadFile(fileName) + if err != nil { + fmt.Printf("read %s failed: %s", fileName, err) + return err + } + fileContent := string(file) + lines := strings.Split(fileContent, "\n") + for i, v := range lines { + WhiteLists[v] = i + } + return nil +} func getHost(addr string) string { return strings.Split(addr, ":")[0] } func main() { + WhiteLists = make(map[string]int, 10) cert := os.Getenv("TLS_CERT") key := os.Getenv("TLS_KEY") @@ -32,6 +53,10 @@ func main() { if err != nil { log.Fatal(err) } + if err := getWhitelist(); err != nil { + log.Println("Get Whitelist failed") + return + } if cert != "" && key != "" { tlsCert, err := tls.LoadX509KeyPair(cert, key) if err != nil { @@ -86,13 +111,23 @@ func main() { } command := &RemoteCommand{} + returnResult := &CommandResponse{} + returnResult.WhiteList = 1 err = receiver.Receive(command) if err != nil { log.Print(err) return } log.Printf("cmd(%s), args(%v)\n", command.Cmd, command.Args) - + if _, ok := WhiteLists[command.Cmd]; !ok { + log.Printf("%s not in WhiteLists", command.Cmd) + returnResult.WhiteList = 0 + err = command.StatusChan.Send(returnResult) + if err != nil { + log.Print(err) + } + return + } cmd := exec.Command(command.Cmd, command.Args...) cmd.Stdout = command.Stdout cmd.Stderr = command.Stderr @@ -111,7 +146,6 @@ func main() { defer command.Stdout.Close() defer command.Stderr.Close() - returnResult := &CommandResponse{} err = cmd.Start() if err != nil { // send return status back -- Gitee