From c3a50288dfdb2d46d0d3beebfb7a7e26959a50d8 Mon Sep 17 00:00:00 2001 From: Nils Lagerkvist Date: Sat, 5 Dec 2015 14:18:29 +0100 Subject: [PATCH] Added a sh module for better shell access --- luasrc/lua-generated.go | 223 -------------------------- luasrc/sh.lua | 186 ---------------------- luasrc/shell.lua | 35 ----- sh/sh.go | 53 +++++++ sh/sh_test.go | 271 +++++++++++++++++++++++++++++++ sh/shellcommand.go | 341 ++++++++++++++++++++++++++++++++++++++++ sh/stderr.test.sh | 3 + state.go | 7 +- string.go | 20 +++ 9 files changed, 690 insertions(+), 449 deletions(-) delete mode 100644 luasrc/sh.lua delete mode 100644 luasrc/shell.lua create mode 100644 sh/sh.go create mode 100644 sh/sh_test.go create mode 100644 sh/shellcommand.go create mode 100755 sh/stderr.test.sh diff --git a/luasrc/lua-generated.go b/luasrc/lua-generated.go index 03f5dbc..6eb81f1 100644 --- a/luasrc/lua-generated.go +++ b/luasrc/lua-generated.go @@ -1,227 +1,4 @@ package luasrc const ( -Sh = `--[[ -The MIT License (MIT) - -Copyright (c) 2015 Serge Zaitsev - -Permission is hereby granted, free of charge, to any person obtaining a copy -of this software and associated documentation files (the "Software"), to deal -in the Software without restriction, including without limitation the rights -to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -copies of the Software, and to permit persons to whom the Software is -furnished to do so, subject to the following conditions: - -The above copyright notice and this permission notice shall be included in all -copies or substantial portions of the Software. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -SOFTWARE. -]] -local M = {} -local abort = true - -local function copy(obj, seen) - if type(obj) ~= 'table' then return obj end - if seen and seen[obj] then return seen[obj] end - local s = seen or {} - local res = setmetatable({}, getmetatable(obj)) - s[obj] = res - for k, v in pairs(obj) do res[copy(k, s)] = copy(v, s) end - return res -end - --- converts key and it's argument to "-k" or "-k=v" or just "" -local function arg(k, a) - if not a then return k end - if type(a) == 'string' and #a > 0 then return k..'='..a end - if type(a) == 'number' then return k..'='..tostring(a) end - if type(a) == 'boolean' and a == true then return k end - if type(a) == 'function' then return "" end - error('invalid argument type', type(a), a) -end - --- converts nested tables into a flat list of arguments and concatenated input -local function flatten(t) - local result = {args = {}, input = ''} - - local function f(t) - local keys = {} - for k, v in ipairs(t) do - keys[k] = true - if type(v) == 'table' then - f(v) - else - table.insert(result.args, v) - end - end - for k, v in pairs(t) do - if k == '__input' then - result.input = result.input .. v - elseif k == 'cmd' then - elseif k == 'stdout' then - elseif k == 'stderr' then - elseif k == 'exitcode' then - elseif not keys[k] and k:sub(1, 1) ~= '_' then - local key = '-'..k - if #k > 1 then key = '-' ..key end - table.insert(result.args, arg(key, v)) - end - end - end - - f(t) - return result -end - --- returns a function that executes the command with given args and returns its --- output, exit status etc -local function command(cmd, ...) - local prearg = {...} - return function(...) - local args = flatten({...}) - local s = cmd - for _, v in ipairs(prearg) do - s = s .. ' ' .. v - end - for k, v in pairs(args.args) do - s = s .. ' ' .. v - end - - if args.input then - local f = io.open(M.tmpfile, 'w') - f:write(args.input) - f:close() - s = s .. ' <'..M.tmpfile - end - print("cmd", s) - - local exit, output, stderr = blade.system(s) - os.remove(M.tmpfile) - - local t = { - __input = output, - cmd = cmd, - stdout = output, - stderr = stderr, - exitcode = exit, - print = function(self) - io.write(self.__input) - return self - end, - lines = function(self) - s = tostring(self.__input) - if s:sub(-1)~="\n" then s=s.."\n" end - return s:gmatch("(.-)\n") - end, - } - local mt = { - __index = function(self, k, ...) - if self.exitcode ~= 0 then - if abort then - os.exit(self.exitcode) - end - return function() - return self - end - end - return command(k) - end, - __tostring = function(self) - -- return trimmed command output as a string - return self.__input:match('^%s*(.-)%s*$') - end - } - return setmetatable(t, mt) - end -end - --- creates sub commands -local function subcommand(...) - local prearg = {...} - return setmetatable({}, { - __call = function(_, cmd, ...) - local foo = copy(prearg) - table.insert(foo, cmd) - return command(unpack(foo))(...) - end, - __index = function(t, cmd) - local foo = copy(prearg) - table.insert(foo, cmd) - return command(unpack(foo)) - end - }) -end - --- get global metatable -local mt = getmetatable(_G) -if mt == nil then - mt = {} - setmetatable(_G, mt) -end - --- export command() function and configurable temporary "input" file -M.command = command -M.abort = function(abrt) - abort = abrt -end -M.subcommand = subcommand -M.tmpfile = '/tmp/shluainput' -M.git = subcommand('git') -M.sudo = subcommand('sudo') - --- allow to call sh to run shell commands -setmetatable(M, { - __call = function(_, cmd, ...) - return command(cmd)(...) - end, - __index = function(t, cmd) - return command(cmd) - end -}) - -return M -` -Shell = `local sh = require('sh') - -local M = {} - --- get global metatable -local mt = getmetatable(_G) -if mt == nil then - mt = {} - setmetatable(_G, mt) -end - --- set hook for undefined variables -mt.__index = function(t, cmd) - return sh.command(cmd) -end - -git = sh.git -sudo = sh.sudo - --- export command() function and configurable temporary "input" file -M.command = sh.command -M.subcommand = sh.subcommand -M.tmpfile = '/tmp/shluainput' - --- allow to call sh to run shell commands -setmetatable(M, { - __call = function(_, cmd, ...) - return sh.command(cmd, ...) - end, - __index = function(t, cmd) - return sh.command(cmd) - end -}) - -return M -` ) diff --git a/luasrc/sh.lua b/luasrc/sh.lua deleted file mode 100644 index 1821558..0000000 --- a/luasrc/sh.lua +++ /dev/null @@ -1,186 +0,0 @@ ---[[ -The MIT License (MIT) - -Copyright (c) 2015 Serge Zaitsev - -Permission is hereby granted, free of charge, to any person obtaining a copy -of this software and associated documentation files (the "Software"), to deal -in the Software without restriction, including without limitation the rights -to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -copies of the Software, and to permit persons to whom the Software is -furnished to do so, subject to the following conditions: - -The above copyright notice and this permission notice shall be included in all -copies or substantial portions of the Software. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -SOFTWARE. -]] -local M = {} -local abort = true - -local function copy(obj, seen) - if type(obj) ~= 'table' then return obj end - if seen and seen[obj] then return seen[obj] end - local s = seen or {} - local res = setmetatable({}, getmetatable(obj)) - s[obj] = res - for k, v in pairs(obj) do res[copy(k, s)] = copy(v, s) end - return res -end - --- converts key and it's argument to "-k" or "-k=v" or just "" -local function arg(k, a) - if not a then return k end - if type(a) == 'string' and #a > 0 then return k..'='..a end - if type(a) == 'number' then return k..'='..tostring(a) end - if type(a) == 'boolean' and a == true then return k end - if type(a) == 'function' then return "" end - error('invalid argument type', type(a), a) -end - --- converts nested tables into a flat list of arguments and concatenated input -local function flatten(t) - local result = {args = {}, input = ''} - - local function f(t) - local keys = {} - for k, v in ipairs(t) do - keys[k] = true - if type(v) == 'table' then - f(v) - else - table.insert(result.args, v) - end - end - for k, v in pairs(t) do - if k == '__input' then - result.input = result.input .. v - elseif k == 'cmd' then - elseif k == 'stdout' then - elseif k == 'stderr' then - elseif k == 'exitcode' then - elseif not keys[k] and k:sub(1, 1) ~= '_' then - local key = '-'..k - if #k > 1 then key = '-' ..key end - table.insert(result.args, arg(key, v)) - end - end - end - - f(t) - return result -end - --- returns a function that executes the command with given args and returns its --- output, exit status etc -local function command(cmd, ...) - local prearg = {...} - return function(...) - local args = flatten({...}) - local s = cmd - for _, v in ipairs(prearg) do - s = s .. ' ' .. v - end - for k, v in pairs(args.args) do - s = s .. ' ' .. v - end - - if args.input then - local f = io.open(M.tmpfile, 'w') - f:write(args.input) - f:close() - s = s .. ' <'..M.tmpfile - end - print("cmd", s) - - local exit, output, stderr = blade.system(s) - os.remove(M.tmpfile) - - local t = { - __input = output, - cmd = cmd, - stdout = output, - stderr = stderr, - exitcode = exit, - print = function(self) - io.write(self.__input) - return self - end, - lines = function(self) - s = tostring(self.__input) - if s:sub(-1)~="\n" then s=s.."\n" end - return s:gmatch("(.-)\n") - end, - } - local mt = { - __index = function(self, k, ...) - if self.exitcode ~= 0 then - if abort then - os.exit(self.exitcode) - end - return function() - return self - end - end - return command(k) - end, - __tostring = function(self) - -- return trimmed command output as a string - return self.__input:match('^%s*(.-)%s*$') - end - } - return setmetatable(t, mt) - end -end - --- creates sub commands -local function subcommand(...) - local prearg = {...} - return setmetatable({}, { - __call = function(_, cmd, ...) - local foo = copy(prearg) - table.insert(foo, cmd) - return command(unpack(foo))(...) - end, - __index = function(t, cmd) - local foo = copy(prearg) - table.insert(foo, cmd) - return command(unpack(foo)) - end - }) -end - --- get global metatable -local mt = getmetatable(_G) -if mt == nil then - mt = {} - setmetatable(_G, mt) -end - --- export command() function and configurable temporary "input" file -M.command = command -M.abort = function(abrt) - abort = abrt -end -M.subcommand = subcommand -M.tmpfile = '/tmp/shluainput' -M.git = subcommand('git') -M.sudo = subcommand('sudo') - --- allow to call sh to run shell commands -setmetatable(M, { - __call = function(_, cmd, ...) - return command(cmd)(...) - end, - __index = function(t, cmd) - return command(cmd) - end -}) - -return M diff --git a/luasrc/shell.lua b/luasrc/shell.lua deleted file mode 100644 index 99c3a4d..0000000 --- a/luasrc/shell.lua +++ /dev/null @@ -1,35 +0,0 @@ -local sh = require('sh') - -local M = {} - --- get global metatable -local mt = getmetatable(_G) -if mt == nil then - mt = {} - setmetatable(_G, mt) -end - --- set hook for undefined variables -mt.__index = function(t, cmd) - return sh.command(cmd) -end - -git = sh.git -sudo = sh.sudo - --- export command() function and configurable temporary "input" file -M.command = sh.command -M.subcommand = sh.subcommand -M.tmpfile = '/tmp/shluainput' - --- allow to call sh to run shell commands -setmetatable(M, { - __call = function(_, cmd, ...) - return sh.command(cmd, ...) - end, - __index = function(t, cmd) - return sh.command(cmd) - end -}) - -return M diff --git a/sh/sh.go b/sh/sh.go new file mode 100644 index 0000000..4b7b0ee --- /dev/null +++ b/sh/sh.go @@ -0,0 +1,53 @@ +package sh + +import "github.com/yuin/gopher-lua" + +var exports = map[string]lua.LGFunction{} + +// Loader is used for preloading a module +func Loader(L *lua.LState) int { + + // register functions to the table + mod := L.SetFuncs(L.NewTable(), exports) + + // set up meta table + mt := L.NewTable() + L.SetField(mt, "__index", L.NewClosure(moduleIndex)) + L.SetField(mt, "__call", L.NewClosure(moduleCall)) + L.SetMetatable(mod, mt) + + shMetaTable := L.NewTypeMetatable(luaShTypeName) + L.SetField(shMetaTable, "__call", L.NewFunction(shCall)) + L.SetField(shMetaTable, "__index", L.NewFunction(shIndex)) + + // returns the module + L.Push(mod) + return 1 +} + +// moduleIndex creates and returns userdata shell command (sh) defined by the +// index. +func moduleIndex(L *lua.LState) int { + index := L.CheckString(2) + + cmd := &shellCommand{ + path: index, + } + + L.Push(cmd.UserData(L)) + return 1 +} + +func moduleCall(L *lua.LState) int { + path := L.CheckString(2) + args := checkStrings(L, 3) + + cmd, err := newShellCommand(path, args...) + checkError(L, err) + + err = cmd.command.Start() + checkError(L, err) + + L.Push(cmd.UserData(L)) + return 1 +} diff --git a/sh/sh_test.go b/sh/sh_test.go new file mode 100644 index 0000000..554b70f --- /dev/null +++ b/sh/sh_test.go @@ -0,0 +1,271 @@ +package sh + +import ( + "bytes" + "io" + "io/ioutil" + "os" + "testing" + + "github.com/yuin/gopher-lua" +) + +func captureStdOut() func() string { + old := os.Stdout // keep backup of the real stdout + r, w, _ := os.Pipe() + os.Stdout = w + + outC := make(chan string) + // copy the output in a separate goroutine so printing can't block indefinitely + go func() { + var buf bytes.Buffer + io.Copy(&buf, r) + outC <- buf.String() + }() + + return func() string { + w.Close() + os.Stdout = old // restoring the real stdout + return <-outC + } +} + +func doString(src string, t *testing.T) string { + L := lua.NewState() + defer L.Close() + L.PreloadModule("sh", Loader) + + restorer := captureStdOut() + err := L.DoString(src) + out := restorer() + if err != nil { + t.Errorf("unable to run source: %v", err) + } + if len(out) == 0 { + return out + } + + return out[0 : len(out)-1] +} + +func TestModuleCall(t *testing.T) { + src := ` + local sh = require('sh') + sh("echo", "foo", "bar"):print() + ` + expected := "foo bar" + got := doString(src, t) + + if got != expected { + t.Errorf("expected: %v, got: %v\nsrc: %v", expected, got, src) + } +} + +func TestIndexCall(t *testing.T) { + src := ` + local sh = require('sh') + sh.echo("foo", "bar"):print() + ` + expected := "foo bar" + got := doString(src, t) + + if got != expected { + t.Errorf("expected: %v, got: %v\nsrc: %v", expected, got, src) + } +} + +func TestPipe(t *testing.T) { + src := ` + local sh = require('sh') + sh.echo("foo", "bar\n", "biz", "buz"):grep("foo"):print() + ` + expected := "foo bar" + got := doString(src, t) + + if got != expected { + t.Errorf("expected: `%v`, got: `%v`\nsrc: %v", expected, got, src) + } +} + +func TestLines(t *testing.T) { + src := ` + local sh = require('sh') + for line in sh.echo("foo bar\nbiz", "buz"):lines() do + print(line) + end + ` + expected := "foo bar\nbiz buz" + got := doString(src, t) + + if got != expected { + t.Errorf("expected: `%v`, got: `%v`\nsrc: %v", expected, got, src) + } +} + +func TestOK(t *testing.T) { + src := ` + local sh = require('sh') + sh.echo("foo"):ok() + print("ok") + ` + expected := "ok" + got := doString(src, t) + + if got != expected { + t.Errorf("expected: `%v`, got: `%v`\nsrc: %v", expected, got, src) + } +} + +func TestNotOK(t *testing.T) { + src := ` + local sh = require('sh') + function fail() + sh.grep("-d"):ok() + end + + ok, err = pcall(fail) + print(ok) + print(err) + ` + expected := `false +:4: exit status 2` + got := doString(src, t) + + if got != expected { + t.Errorf("expected: `%v`, got: `%v`\nsrc: %v", expected, got, src) + } +} + +func TestSuccess(t *testing.T) { + src := ` + local sh = require('sh') + ok = sh.echo("foo"):success() + print(ok) + ` + expected := "true" + got := doString(src, t) + + if got != expected { + t.Errorf("expected: `%v`, got: `%v`\nsrc: %v", expected, got, src) + } +} + +func TestNotSuccess(t *testing.T) { + src := ` + local sh = require('sh') + ok = sh.grep("-d"):success() + + print(ok) + ` + expected := "false" + got := doString(src, t) + + if got != expected { + t.Errorf("expected: `%v`, got: `%v`\nsrc: %v", expected, got, src) + } +} + +func TestExitcode(t *testing.T) { + src := ` + local sh = require('sh') + exitcode = sh.echo("foo"):exitcode() + print(exitcode) + ` + expected := "0" + got := doString(src, t) + + if got != expected { + t.Errorf("expected: `%v`, got: `%v`\nsrc: %v", expected, got, src) + } +} + +func TestNotExitcode(t *testing.T) { + src := ` + local sh = require('sh') + exitcode = sh.grep("-d"):exitcode() + + print(exitcode) + ` + expected := "2" + got := doString(src, t) + + if got != expected { + t.Errorf("expected: `%v`, got: `%v`\nsrc: %v", expected, got, src) + } +} + +func TestStdout(t *testing.T) { + src := ` + local sh = require('sh') + out = sh.echo("foo"):stdout() + print(out) + ` + expected := "foo" + "\n" + got := doString(src, t) + + if got != expected { + t.Errorf("expected: `%v`, got: `%v`\nsrc: %v", expected, got, src) + } +} + +func TestStderr(t *testing.T) { + src := ` + local sh = require('sh') + out = sh("./stderr.test.sh"):stderr() + print(out) + ` + expected := "foo" + "\n" + got := doString(src, t) + + if got != expected { + t.Errorf("expected: `%v`, got: `%v`\nsrc: %v", expected, got, src) + } +} + +func TestWriteStdoutToFile(t *testing.T) { + src := ` + local sh = require('sh') + tmp = "./remove.me" + out = sh.echo("foo"):stdout(tmp) + print(out) + ` + expected := "foo" + "\n" + file := "./remove.me" + defer os.Remove(file) + got := doString(src, t) + + if got != expected { + t.Errorf("expected stdout: `%v`, got: `%v`\nsrc: %v", expected, got, src) + } + dat, err := ioutil.ReadFile(file) + if err != nil { + t.Errorf("unable to read file: `%v`", file) + } + if string(dat) != expected { + t.Errorf("expected file: `%v`, got: `%v`\nsrc: %v", expected, string(dat), src) + } +} + +func TestWriteStderrToFile(t *testing.T) { + src := ` + local sh = require('sh') + tmp = "./remove.me" + out = sh("./stderr.test.sh"):stderr(tmp) + print(out) + ` + expected := "foo" + "\n" + file := "./remove.me" + defer os.Remove(file) + got := doString(src, t) + + if got != expected { + t.Errorf("expected stdout: `%v`, got: `%v`\nsrc: %v", expected, got, src) + } + dat, err := ioutil.ReadFile(file) + if err != nil { + t.Errorf("unable to read file: `%v`", file) + } + if string(dat) != expected { + t.Errorf("expected file: `%v`, got: `%v`\nsrc: %v", expected, string(dat), src) + } +} diff --git a/sh/shellcommand.go b/sh/shellcommand.go new file mode 100644 index 0000000..3a4b128 --- /dev/null +++ b/sh/shellcommand.go @@ -0,0 +1,341 @@ +package sh + +import ( + "bufio" + "bytes" + "fmt" + "io" + "io/ioutil" + "log" + "os" + "os/exec" + "syscall" + + "github.com/yuin/gopher-lua" +) + +const luaShTypeName = "sh" + +type shellCommand struct { + path string + args []string + command *exec.Cmd + stdout io.ReadCloser + stderr io.ReadCloser + stdin io.ReadCloser + + stdoutClosed bool + stderrClosed bool +} + +func newShellCommand(path string, args ...string) (*shellCommand, error) { + cmd := &shellCommand{ + path: path, + } + + err := cmd.Command(path, args...) + if err != nil { + return nil, err + } + return cmd, nil +} + +func (s *shellCommand) Command(path string, args ...string) error { + s.command = exec.Command(path, args...) + + stdout, err := s.command.StdoutPipe() + if err != nil { + return err + } + + stderr, err := s.command.StderrPipe() + if err != nil { + return err + } + + s.stdout = stdout + s.stderr = stderr + + return nil +} + +func (s *shellCommand) UserData(L *lua.LState) *lua.LUserData { + ud := L.NewUserData() + ud.Value = s + L.SetMetatable(ud, L.GetTypeMetatable(luaShTypeName)) + return ud +} + +func (s *shellCommand) CloseStdout() { + s.stdout.Close() + s.stdoutClosed = true +} + +func (s *shellCommand) CloseStderr() { + s.stderr.Close() + s.stderrClosed = true +} + +func (s *shellCommand) Close(std string) { + switch std { + case "stderr": + s.CloseStderr() + case "stdout": + s.CloseStdout() + default: + log.Fatalf("unable to close stream: %v", std) + } +} + +func (s *shellCommand) IsClosed(std string) bool { + switch std { + case "stderr": + return s.stderrClosed + case "stdout": + return s.stdoutClosed + default: + log.Fatalf("unable to close stream: %v", std) + } + return false +} + +// shIndex checks if it is a predefined method or if it should be interprited as +// shell command. +func shIndex(L *lua.LState) int { + index := L.CheckString(2) + + switch index { + case "print": + L.Push(L.NewFunction(shPrint)) + return 1 + case "ok": + L.Push(L.NewFunction(shOk)) + return 1 + case "lines": + L.Push(L.NewFunction(shLines)) + return 1 + case "success": + L.Push(L.NewFunction(shSuccess)) + return 1 + case "exitcode": + L.Push(L.NewFunction(shExitCode)) + return 1 + case "stdout", "stderr": + L.Push(L.NewFunction(shOutput(index))) + return 1 + default: + return shCmd(L) + } + +} + +func shCmd(L *lua.LState) int { + shellCmd := checkShellCmd(L) + index := L.CheckString(2) + + go shellCmd.command.Wait() + + cmd := &shellCommand{ + path: index, + stdin: shellCmd.stdout, + } + + L.Push(cmd.UserData(L)) + return 1 +} + +// shCall executes the shell command and returns it self +func shCall(L *lua.LState) int { + ud := L.CheckUserData(1) + args := checkStrings(L, 2) + shellCmd := checkShellCmd(L) + err := shellCmd.Command(shellCmd.path, args...) + checkError(L, err) + + if shellCmd.stdin != nil { + shellCmd.command.Stdin = shellCmd.stdin + } + + err = shellCmd.command.Start() + checkError(L, err) + + L.Push(ud) + return 1 +} + +func shOutput(std string) lua.LGFunction { + return func(L *lua.LState) int { + shellCmd := checkShellCmd(L) + file := L.OptString(2, "") + + stream := shellCmd.stdout + if std == "stderr" { + stream = shellCmd.stderr + } + + buf := new(bytes.Buffer) + _, err := buf.ReadFrom(stream) + if err != nil { + L.RaiseError("Unable to read from `%v` several times", std) + } + + shellCmd.Close(std) + + if file != "" { + err := ioutil.WriteFile(file, buf.Bytes(), 0644) + checkError(L, err) + } + + out := buf.String() + L.Push(lua.LString(out)) + return 1 + } +} + +func shOk(L *lua.LState) int { + ud := L.CheckUserData(1) + shellCmd := checkShellCmd(L) + + exitcode, err := wait(shellCmd) + checkError(L, err) + + if exitcode != 0 { + L.RaiseError("exit status %v", exitcode) + } + + L.Push(ud) + return 1 +} + +func shSuccess(L *lua.LState) int { + shellCmd := checkShellCmd(L) + + errorCode, err := wait(shellCmd) + checkError(L, err) + + L.Push(lua.LBool(errorCode == 0)) + return 1 +} + +func shExitCode(L *lua.LState) int { + shellCmd := checkShellCmd(L) + exitcode, err := wait(shellCmd) + checkError(L, err) + L.Push(lua.LNumber(exitcode)) + return 1 +} + +func shLines(L *lua.LState) int { + shellCmd := checkShellCmd(L) + std := L.OptString(2, "stdout") + + if !(std == "stdout" || std == "stderr") { + L.RaiseError("lines: illigal file handle `%v`", std) + } + if shellCmd.IsClosed(std) { + L.RaiseError("Unable to read from `%v` several times", std) + } + + stream := shellCmd.stdout + if std == "stderr" { + stream = shellCmd.stderr + } + + scanner := bufio.NewScanner(stream) + scanner.Split(bufio.ScanLines) + iterator := func(L *lua.LState) int { + if scanner.Scan() { + L.Push(lua.LString(scanner.Text())) + return 1 + } + + shellCmd.Close(std) + checkError(L, scanner.Err()) + return 0 + } + + L.Push(L.NewFunction(iterator)) + return 1 +} + +func shPrint(L *lua.LState) int { + ud := L.CheckUserData(1) + shellCmd := checkShellCmd(L) + if shellCmd.stderrClosed || shellCmd.stdoutClosed { + L.RaiseError("Unable to read from `%v` several times", "stdout/stderr") + } + + combo := io.MultiReader(shellCmd.stdout, shellCmd.stderr) + + _, err := io.Copy(os.Stdout, combo) + checkError(L, err) + + err = shellCmd.command.Wait() + if err != nil && !isExitError(err) { + L.RaiseError("Error while waiting for command to finish: %v", err) + } + + shellCmd.CloseStderr() + shellCmd.CloseStdout() + L.Push(ud) + return 1 +} + +// check if it is a shellCmd userdata as the first parmeter +func checkShellCmd(L *lua.LState) *shellCommand { + ud := L.CheckUserData(1) + shellCmd, ok := ud.Value.(*shellCommand) + if !ok { + L.Error(lua.LString("Expected the user data should be a shell command"), 0) + return nil + } + return shellCmd +} + +// converts all input parameters to strings from the n:th element +func checkStrings(L *lua.LState, n int) []string { + params := L.GetTop() + if n > params { + return []string{} + } + args := make([]string, 0, (params - n + 1)) + for i := n; i <= params; i++ { + p := L.Get(i) + if p.Type() == lua.LTUserData { + continue + } + L.CheckTypes(i, lua.LTString, lua.LTNumber) + args = append(args, p.String()) + } + return args +} + +func checkError(L *lua.LState, err error) { + if err != nil { + L.RaiseError("%v", err) + } +} + +func isExitError(err error) bool { + _, ok := err.(*exec.ExitError) + return ok +} + +func wait(shellCmd *shellCommand) (exitcode int, err error) { + if shellCmd.command.ProcessState == nil { + err := shellCmd.command.Wait() + if err != nil && !isExitError(err) { + return 0, err + } + } + + if shellCmd.command.ProcessState.Success() { + return 0, nil + } + + if status, ok := shellCmd.command.ProcessState.Sys().(syscall.WaitStatus); ok { + return status.ExitStatus(), nil + } + + return 0, fmt.Errorf("`%v`: error retreiving exit code", shellCmd.command.Args) +} diff --git a/sh/stderr.test.sh b/sh/stderr.test.sh new file mode 100755 index 0000000..7924ed3 --- /dev/null +++ b/sh/stderr.test.sh @@ -0,0 +1,3 @@ +#!/bin/sh + +echo "foo" >&2 diff --git a/state.go b/state.go index c4f177d..e8569a4 100644 --- a/state.go +++ b/state.go @@ -4,8 +4,8 @@ import ( "fmt" "os" - "github.com/otm/blade/luasrc" "github.com/otm/blade/parser" + "github.com/otm/blade/sh" "github.com/yuin/gopher-lua" ) @@ -41,10 +41,7 @@ func setupEnv() (L *lua.LState, runner *lua.LTable, cmd *lua.LTable) { L.SetGlobal("blade", blade) emit("Preloading module: sh") - L.PreloadModule("sh", loader(luasrc.Sh)) - - emit("Preloading module: shell") - L.PreloadModule("shell", loader(luasrc.Shell)) + L.PreloadModule("sh", sh.Loader) emit("Setting up cmd\n") cmds := L.NewTable() diff --git a/string.go b/string.go index 0f8b0dd..59dba2c 100644 --- a/string.go +++ b/string.go @@ -11,6 +11,7 @@ func decorateStringLib(L *lua.LState) { mt.RawSetString("split", L.NewClosure(split)) mt.RawSetString("c", L.NewClosure(word)) mt.RawSetString("trim", L.NewClosure(trim)) + mt.RawSetString("fields", L.NewFunction(fields)) } func split(L *lua.LState) int { @@ -61,6 +62,25 @@ func word(L *lua.LState) int { return 1 } +func fields(L *lua.LState) int { + s := L.CheckString(1) + i := 0 + + parts := strings.Fields(s) + iterator := func(L *lua.LState) int { + if i == len(parts) { + return 0 + } + + L.Push(lua.LString(parts[i])) + i++ + return 1 + } + + L.Push(L.NewFunction(iterator)) + return 1 +} + func trim(L *lua.LState) int { s := L.CheckString(1) cutset := L.OptString(2, "\n ")