From 2d0b10e080b6951c5d5c1408d043d6b7cc69a98e Mon Sep 17 00:00:00 2001 From: Nathan Rijksen Date: Fri, 5 Jul 2024 14:35:48 -0700 Subject: [PATCH] Implement powershell support --- internal/assets/contents/shells/pwsh.ps1 | 33 ++++ .../assets/contents/shells/pwsh_global.ps1 | 17 ++ internal/subshell/pwsh/pwsh.go | 169 ++++++++++++++++++ internal/subshell/sscommon/rcfile.go | 24 ++- internal/subshell/sscommon/sscommon.go | 6 +- internal/subshell/subshell.go | 34 +++- internal/testhelpers/e2e/spawn.go | 9 +- test/integration/shell_int_test.go | 28 +++ tmp/test.go | 1 + 9 files changed, 310 insertions(+), 11 deletions(-) create mode 100644 internal/assets/contents/shells/pwsh.ps1 create mode 100644 internal/assets/contents/shells/pwsh_global.ps1 create mode 100644 internal/subshell/pwsh/pwsh.go create mode 100644 tmp/test.go diff --git a/internal/assets/contents/shells/pwsh.ps1 b/internal/assets/contents/shells/pwsh.ps1 new file mode 100644 index 0000000000..8029c280b1 --- /dev/null +++ b/internal/assets/contents/shells/pwsh.ps1 @@ -0,0 +1,33 @@ +{{if and (ne .Project "") (not .PreservePs1) }} +$prevPrompt = $ExecutionContext.SessionState.PSVariable.GetValue('prompt') +if ($prevPrompt -eq $null) { + $prevPrompt = "PS $PWD> " +} +function prompt { + "[{{.Project}}] $prevPrompt" +} +{{end}} + +cd "{{.WD}}" + +{{- range $K, $V := .Env}} +{{- if eq $K "PATH"}} +$env:PATH = "{{ escapePwsh $V}};$env:PATH" +{{- else}} +$env:{{$K}} = "{{ escapePwsh $V }}" +{{- end}} +{{- end}} + +{{ if .ExecAlias }} +New-Alias {{.ExecAlias}} {{.ExecName}} +{{ end }} + +{{range $K, $CMD := .Scripts}} +function {{$K}} { + & {{$.ExecName}} run {{$CMD}} $args +} +{{end}} + +echo "{{ escapePwsh .ActivatedMessage}}" + +{{.UserScripts}} diff --git a/internal/assets/contents/shells/pwsh_global.ps1 b/internal/assets/contents/shells/pwsh_global.ps1 new file mode 100644 index 0000000000..4a44a3a2c2 --- /dev/null +++ b/internal/assets/contents/shells/pwsh_global.ps1 @@ -0,0 +1,17 @@ +{{if and (ne .Project "") (not .PreservePs1) }} +$prevPrompt = $ExecutionContext.SessionState.PSVariable.GetValue('prompt') +if ($prevPrompt -eq $null) { + $prevPrompt = "PS $PWD> " +} +function prompt { + "[{{.Project}}] $prevPrompt" +} +{{end}} + +{{- range $K, $V := .Env}} +{{- if eq $K "PATH"}} +$env:{{$K}} = "{{ escapePwsh $V }};$env:PATH" +{{- else}} +$env:{{$K}} = "{{ escapePwsh $V }}" +{{- end}} +{{- end}} \ No newline at end of file diff --git a/internal/subshell/pwsh/pwsh.go b/internal/subshell/pwsh/pwsh.go new file mode 100644 index 0000000000..f5422be9f3 --- /dev/null +++ b/internal/subshell/pwsh/pwsh.go @@ -0,0 +1,169 @@ +package pwsh + +import ( + "fmt" + "os" + "os/exec" + "path/filepath" + "runtime" + + "github.com/ActiveState/cli/internal/errs" + "github.com/ActiveState/cli/internal/fileutils" + "github.com/ActiveState/cli/internal/locale" + "github.com/ActiveState/cli/internal/osutils" + "github.com/ActiveState/cli/internal/osutils/user" + "github.com/ActiveState/cli/internal/output" + "github.com/ActiveState/cli/internal/subshell/cmd" + "github.com/ActiveState/cli/internal/subshell/sscommon" + "github.com/ActiveState/cli/pkg/project" +) + +var escaper *osutils.ShellEscape + +func init() { + escaper = osutils.NewBatchEscaper() +} + +// SubShell covers the subshell.SubShell interface, reference that for documentation +type SubShell struct { + binary string + rcFile *os.File + cmd *exec.Cmd + env map[string]string + errs chan error +} + +const Name string = "powershell" + +// Shell - see subshell.SubShell +func (v *SubShell) Shell() string { + return Name +} + +// Binary - see subshell.SubShell +func (v *SubShell) Binary() string { + return v.binary +} + +// SetBinary - see subshell.SubShell +func (v *SubShell) SetBinary(binary string) { + v.binary = binary +} + +// WriteUserEnv - see subshell.SubShell +func (v *SubShell) WriteUserEnv(cfg sscommon.Configurable, env map[string]string, envType sscommon.RcIdentification, userScope bool) error { + cmdShell := &cmd.SubShell{} + if err := cmdShell.WriteUserEnv(cfg, env, envType, userScope); err != nil { + return errs.Wrap(err, "Forwarded WriteUserEnv call failed") + } + + return nil +} + +func (v *SubShell) CleanUserEnv(cfg sscommon.Configurable, envType sscommon.RcIdentification, userScope bool) error { + cmdShell := &cmd.SubShell{} + if err := cmdShell.CleanUserEnv(cfg, envType, userScope); err != nil { + return errs.Wrap(err, "Forwarded CleanUserEnv call failed") + } + + return nil +} + +func (v *SubShell) RemoveLegacyInstallPath(_ sscommon.Configurable) error { + return nil +} + +func (v *SubShell) WriteCompletionScript(completionScript string) error { + return locale.NewError("err_writecompletions_notsupported", "{{.V0}} does not support completions.", v.Shell()) +} + +func (v *SubShell) RcFile() (string, error) { + home, err := user.HomeDir() + if err != nil { + return "", errs.Wrap(err, "Could not get home dir") + } + + return filepath.Join(home, "Documents", "WindowsPowerShell", "Microsoft.PowerShell_profile.ps1"), nil +} + +func (v *SubShell) EnsureRcFileExists() error { + rcFile, err := v.RcFile() + if err != nil { + return errs.Wrap(err, "Could not determine rc file") + } + + return fileutils.TouchFileUnlessExists(rcFile) +} + +// SetupShellRcFile - subshell.SubShell +func (v *SubShell) SetupShellRcFile(targetDir string, env map[string]string, namespace *project.Namespaced, cfg sscommon.Configurable) error { + env = sscommon.EscapeEnv(env) + return sscommon.SetupShellRcFile(filepath.Join(targetDir, "shell.ps1"), "pwsh_global.ps1", env, namespace, cfg) +} + +// SetEnv - see subshell.SetEnv +func (v *SubShell) SetEnv(env map[string]string) error { + v.env = env + return nil +} + +// Quote - see subshell.Quote +func (v *SubShell) Quote(value string) string { + return escaper.Quote(value) +} + +// Activate - see subshell.SubShell +func (v *SubShell) Activate(prj *project.Project, cfg sscommon.Configurable, out output.Outputer) error { + var shellArgs []string + var directEnv []string + + if prj != nil { + var err error + if v.rcFile, err = sscommon.SetupProjectRcFile(prj, "pwsh.ps1", ".ps1", v.env, out, cfg, false); err != nil { + return err + } + + shellArgs = []string{"-NoExit", "-Command", fmt.Sprintf(". '%s'", v.rcFile.Name())} + } else { + directEnv = sscommon.EnvSlice(v.env) + } + + // powershell -NoExit -Command "& 'C:\Temp\profile.ps1'" + cmd := sscommon.NewCommand(v.binary, shellArgs, directEnv) + v.errs = sscommon.Start(cmd) + v.cmd = cmd + return nil +} + +// Errors returns a channel for receiving errors related to active behavior +func (v *SubShell) Errors() <-chan error { + return v.errs +} + +// Deactivate - see subshell.SubShell +func (v *SubShell) Deactivate() error { + if !v.IsActive() { + return nil + } + + if err := sscommon.Stop(v.cmd); err != nil { + return err + } + + v.cmd = nil + return nil +} + +// Run - see subshell.SubShell +func (v *SubShell) Run(filename string, args ...string) error { + return sscommon.RunFuncByBinary(v.Binary())(osutils.EnvMapToSlice(v.env), filename, args...) +} + +// IsActive - see subshell.SubShell +func (v *SubShell) IsActive() bool { + return v.cmd != nil && (v.cmd.ProcessState == nil || !v.cmd.ProcessState.Exited()) +} + +func (v *SubShell) IsAvailable() bool { + return runtime.GOOS == "windows" +} diff --git a/internal/subshell/sscommon/rcfile.go b/internal/subshell/sscommon/rcfile.go index 14c2c1eef0..0376731a9e 100644 --- a/internal/subshell/sscommon/rcfile.go +++ b/internal/subshell/sscommon/rcfile.go @@ -10,9 +10,10 @@ import ( "strings" "text/template" - "github.com/ActiveState/cli/internal/installation/storage" "github.com/mash/go-tempfile-suffix" + "github.com/ActiveState/cli/internal/installation/storage" + "github.com/ActiveState/cli/internal/assets" "github.com/ActiveState/cli/internal/colorize" "github.com/ActiveState/cli/internal/constants" @@ -252,6 +253,8 @@ func SetupProjectRcFile(prj *project.Project, templateName, ext string, env map[ return nil, errs.Wrap(err, "Failed to read asset") } + logging.Debug("Env: %v", env) + userScripts := "" // Yes this is awkward, issue here - https://www.pivotaltracker.com/story/show/175619373 @@ -332,6 +335,7 @@ func SetupProjectRcFile(prj *project.Project, templateName, ext string, env map[ rcData := map[string]interface{}{ "Owner": prj.Owner(), "Name": prj.Name(), + "Project": prj.NamespaceString(), "Env": actualEnv, "WD": wd, "UserScripts": userScripts, @@ -368,6 +372,22 @@ func SetupProjectRcFile(prj *project.Project, templateName, ext string, env map[ t := template.New("rcfile") t.Funcs(map[string]interface{}{ "splitLines": func(v string) []string { return strings.Split(v, "\n") }, + "escapePwsh": func(v string) string { + // Conver unicode characters + result := "" + for _, char := range v { + if char < 128 { + result += string(char) + } else { + result += fmt.Sprintf("$([char]0x%04x)", char) + } + } + + // Escape special characters + result = strings.ReplaceAll(result, "`", "``") + result = strings.ReplaceAll(result, "\"", "`\"") + return result + }, }) t, err = t.Parse(string(tpl)) @@ -392,8 +412,6 @@ func SetupProjectRcFile(prj *project.Project, templateName, ext string, env map[ return nil, errs.Wrap(err, "Failed to write to output buffer.") } - logging.Debug("Using project RC: (%s) %s", tmpFile.Name(), o.String()) - return tmpFile, nil } diff --git a/internal/subshell/sscommon/sscommon.go b/internal/subshell/sscommon/sscommon.go index 9885d2fa0e..7979ef1504 100644 --- a/internal/subshell/sscommon/sscommon.go +++ b/internal/subshell/sscommon/sscommon.go @@ -84,8 +84,8 @@ func RunFuncByBinary(binary string) RunFunc { switch { case strings.Contains(bin, "bash"): return runWithBash - case strings.Contains(bin, "cmd"): - return runWithCmd + case strings.Contains(bin, "cmd"), strings.Contains(bin, "powershell"): + return runWindowsShell default: return runDirect } @@ -107,7 +107,7 @@ func runWithBash(env []string, name string, args ...string) error { return runDirect(env, "bash", "-c", quotedArgs) } -func runWithCmd(env []string, name string, args ...string) error { +func runWindowsShell(env []string, name string, args ...string) error { ext := filepath.Ext(name) switch ext { case ".py": diff --git a/internal/subshell/subshell.go b/internal/subshell/subshell.go index 0c76098721..b005bf9867 100644 --- a/internal/subshell/subshell.go +++ b/internal/subshell/subshell.go @@ -1,12 +1,15 @@ package subshell import ( + "errors" "os" "os/exec" "path/filepath" "runtime" "strings" + "github.com/shirou/gopsutil/v3/process" + "github.com/ActiveState/cli/internal/errs" "github.com/ActiveState/cli/internal/fileutils" "github.com/ActiveState/cli/internal/logging" @@ -17,6 +20,7 @@ import ( "github.com/ActiveState/cli/internal/subshell/bash" "github.com/ActiveState/cli/internal/subshell/cmd" "github.com/ActiveState/cli/internal/subshell/fish" + "github.com/ActiveState/cli/internal/subshell/pwsh" "github.com/ActiveState/cli/internal/subshell/sscommon" "github.com/ActiveState/cli/internal/subshell/tcsh" "github.com/ActiveState/cli/internal/subshell/zsh" @@ -99,6 +103,8 @@ func New(cfg sscommon.Configurable) SubShell { subs = &fish.SubShell{} case cmd.Name: subs = &cmd.SubShell{} + case pwsh.Name: + subs = &pwsh.SubShell{} default: rollbar.Error("subshell.DetectShell did not return a known name: %s", name) switch runtime.GOOS { @@ -113,7 +119,7 @@ func New(cfg sscommon.Configurable) SubShell { logging.Debug("Using binary: %s", path) subs.SetBinary(path) - + err := subs.SetEnv(osutils.EnvSliceToMap(os.Environ())) if err != nil { // We cannot error here, but this error will resurface when activating a runtime, so we can @@ -177,7 +183,7 @@ func DetectShell(cfg sscommon.Configurable) (string, string) { binary = os.Getenv("SHELL") if binary == "" && runtime.GOOS == "windows" { - binary = os.Getenv("ComSpec") + binary = detectShellWindows() } if binary == "" { @@ -204,7 +210,7 @@ func DetectShell(cfg sscommon.Configurable) (string, string) { } isKnownShell := false - for _, ssName := range []string{bash.Name, cmd.Name, fish.Name, tcsh.Name, zsh.Name} { + for _, ssName := range []string{bash.Name, cmd.Name, fish.Name, tcsh.Name, zsh.Name, pwsh.Name} { if name == ssName { isKnownShell = true break @@ -231,3 +237,25 @@ func DetectShell(cfg sscommon.Configurable) (string, string) { return name, path } + +func detectShellWindows() string { + // Windows does not provide a way of identifying which shell we are running in, so we have to look at the parent + // process. + + p, err := process.NewProcess(int32(os.Getppid())) + if err != nil && !errors.As(err, &os.PathError{}) { + panic(err) + } + + for p != nil { + name, err := p.Name() + if err == nil { + if strings.Contains(name, "cmd.exe") || strings.Contains(name, "powershell.exe") { + return name + } + } + p, _ = p.Parent() + } + + return os.Getenv("ComSpec") +} diff --git a/internal/testhelpers/e2e/spawn.go b/internal/testhelpers/e2e/spawn.go index 33aff86cef..bf73ffcd22 100644 --- a/internal/testhelpers/e2e/spawn.go +++ b/internal/testhelpers/e2e/spawn.go @@ -73,8 +73,13 @@ func (s *SpawnedCmd) ExpectInput(opts ...termtest.SetExpectOpt) error { send := `echo $'expect\'input from posix shell'` expect := `expect'input from posix shell` if cmdName != "bash" && shellName != "bash" && runtime.GOOS == "windows" { - send = `echo ^` - expect = `` + if strings.Contains(cmdName, "powershell") || strings.Contains(shellName, "powershell") { + send = "echo \"`\"" + expect = `` + } else { + send = `echo ^` + expect = `` + } } // Termtest internal functions already implement error handling diff --git a/test/integration/shell_int_test.go b/test/integration/shell_int_test.go index 0d60111e24..fdd6519284 100644 --- a/test/integration/shell_int_test.go +++ b/test/integration/shell_int_test.go @@ -476,6 +476,34 @@ events:`, lang, splat), 1) cp.ExpectExit() // exit code varies depending on shell; just assert the shell exited } +func (suite *ShellIntegrationTestSuite) TestWindowsShells() { + if runtime.GOOS != "windows" { + suite.T().Skip("Windows only test") + } + + suite.OnlyRunForTags(tagsuite.Critical, tagsuite.Shell) + ts := e2e.New(suite.T(), false) + defer ts.Close() + + ts.PrepareProject("ActiveState-CLI/Empty", "6d79f2ae-f8b5-46bd-917a-d4b2558ec7b8") + + hostname, err := os.Hostname() + suite.Require().NoError(err) + cp := ts.SpawnCmd("cmd", "/C", "state", "shell") + cp.ExpectInput() + cp.SendLine("hostname") + cp.Expect(hostname) // cmd.exe shows the actual hostname + cp.SendLine("exit") + cp.ExpectExitCode(0) + + cp = ts.SpawnCmd("powershell", "-Command", "state", "shell") + cp.ExpectInput() + cp.SendLine("$host.name") + cp.Expect("ConsoleHost") // powershell always shows ConsoleHost, go figure + cp.SendLine("exit") + cp.ExpectExitCode(0) +} + func TestShellIntegrationTestSuite(t *testing.T) { suite.Run(t, new(ShellIntegrationTestSuite)) } diff --git a/tmp/test.go b/tmp/test.go new file mode 100644 index 0000000000..14aed2621e --- /dev/null +++ b/tmp/test.go @@ -0,0 +1 @@ +package tmp