diff --git a/common/os.go b/common/os.go index e0a00926..8af6cd22 100644 --- a/common/os.go +++ b/common/os.go @@ -6,6 +6,7 @@ import ( "io" "io/ioutil" "os" + "os/exec" "os/signal" "path/filepath" "strings" @@ -13,9 +14,22 @@ import ( ) var ( - GoPath = os.Getenv("GOPATH") + GoPath = gopath() ) +func gopath() string { + path := os.Getenv("GOPATH") + if len(path) == 0 { + goCmd := exec.Command("go", "env", "GOPATH") + out, err := goCmd.Output() + if err != nil { + panic(fmt.Sprintf("failed to determine gopath: %v", err)) + } + path = string(out) + } + return path +} + func TrapSignal(cb func()) { c := make(chan os.Signal, 1) signal.Notify(c, os.Interrupt, syscall.SIGTERM) diff --git a/common/os_test.go b/common/os_test.go index 05359e36..168eb438 100644 --- a/common/os_test.go +++ b/common/os_test.go @@ -27,3 +27,26 @@ func TestWriteFileAtomic(t *testing.T) { t.Fatal(err) } } + +func TestGoPath(t *testing.T) { + // restore original gopath upon exit + path := os.Getenv("GOPATH") + defer func() { + _ = os.Setenv("GOPATH", path) + }() + + err := os.Setenv("GOPATH", "~/testgopath") + if err != nil { + t.Fatal(err) + } + path = gopath() + if path != "~/testgopath" { + t.Fatalf("gopath should return GOPATH env var if set, got %v", path) + } + os.Unsetenv("GOPATH") + + path = gopath() + if path == "~/testgopath" || path == "" { + t.Fatalf("gopath should return go env GOPATH result if env var does not exist, got %v", path) + } +}