diff options
-rw-r--r-- | .gitignore | 1 | ||||
-rw-r--r-- | cmd/goctl/main.go | 35 | ||||
-rw-r--r-- | goctl.go | 202 | ||||
-rw-r--r-- | goctl_test.go | 178 |
4 files changed, 416 insertions, 0 deletions
diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..7e508bc --- /dev/null +++ b/.gitignore @@ -0,0 +1 @@ +goctl diff --git a/cmd/goctl/main.go b/cmd/goctl/main.go new file mode 100644 index 0000000..1537c9a --- /dev/null +++ b/cmd/goctl/main.go @@ -0,0 +1,35 @@ +package main + +import ( + "flag" + "fmt" + "log" + "net" + "os" + "strings" + + "github.com/bjc/goctl" +) + +func main() { + var path = flag.String("path", "", "Socket path for sending commands (required).") + flag.Parse() + + if *path == "" { + flag.Usage() + os.Exit(1) + } + + c, err := net.Dial("unix", *path) + if err != nil { + log.Fatalf("Couldn't connect to %s: %s.", *path, err) + } + defer c.Close() + + goctl.Write(c, []byte(strings.Join(flag.Args(), " "))) + if buf, err := goctl.Read(c); err != nil { + log.Fatalf("Error reading response from command: %s.", err) + } else { + fmt.Println(string(buf)) + } +} diff --git a/goctl.go b/goctl.go new file mode 100644 index 0000000..85b367f --- /dev/null +++ b/goctl.go @@ -0,0 +1,202 @@ +package goctl + +import ( + "encoding/binary" + "fmt" + "io" + "net" + "os" + "strconv" + "strings" + "sync/atomic" + "time" + + "gopkg.in/inconshreveable/log15.v2" +) + +// How long to wait for a response before returning an error. +const timeout = 100 * time.Millisecond + +// Built-in commands which can't be overridden. +const ( + cmdPing = "ping" + cmdPID = "pid" +) + +var ( + pid int + lastid uint64 + Logger log15.Logger +) + +type Goctl struct { + logger log15.Logger + path string + listener net.Listener + handlers map[string]*Handler +} + +type Handler struct { + Name string + Fn func([]string) string +} + +func init() { + pid = os.Getpid() + Logger = log15.New() + Logger.SetHandler(log15.DiscardHandler()) +} + +func NewGoctl(path string) Goctl { + return Goctl{ + logger: Logger.New("id", atomic.AddUint64(&lastid, 1)), + path: path, + handlers: make(map[string]*Handler), + } +} + +func (gc *Goctl) Start() error { + gc.logger.Info("Starting command listener.") + + if pid := gc.isAlreadyRunning(); pid != "" { + gc.logger.Crit("Command listener already running.", "pid", pid) + return fmt.Errorf("already running on pid %s", pid) + } else { + os.Remove(gc.path) + } + + var err error + gc.listener, err = net.Listen("unix", gc.path) + if err != nil { + gc.logger.Crit("Couldn't listen on socket.", "path", gc.path, "error", err) + return err + } + + go gc.acceptor() + return nil +} + +func (gc *Goctl) Stop() { + gc.logger.Info("Stopping command listener.") + + if gc.listener != nil { + gc.listener.Close() + os.Remove(gc.path) + } +} + +func (gc *Goctl) AddHandler(name string, fn func([]string) string) { + gc.AddHandlers([]*Handler{{name, fn}}) +} + +func (gc *Goctl) AddHandlers(handlers []*Handler) { + for _, h := range handlers { + gc.handlers[h.Name] = h + } +} + +func Read(r io.Reader) ([]byte, error) { + var n uint16 + if err := binary.Read(r, binary.BigEndian, &n); err != nil { + return nil, err + } + + p := make([]byte, n) + if n == 0 { + // Don't attempt to read no data or EOF will be + // returned. Instead, just return an empty byte slice. + return p, nil + } + + if _, err := r.Read(p[:]); err != nil { + return nil, err + } + return p, nil +} + +func Write(w io.Writer, p []byte) error { + if err := binary.Write(w, binary.BigEndian, uint16(len(p))); err != nil { + return err + } + if _, err := w.Write(p); err != nil { + return err + } + return nil +} + +func (gc *Goctl) acceptor() { + for { + c, err := gc.listener.Accept() + if err != nil { + gc.logger.Error("Error accepting connection.", "error", err) + return + } + go gc.reader(c) + } +} + +func (gc *Goctl) reader(c io.ReadWriteCloser) error { + defer gc.logger.Info("Connection closed.") + defer c.Close() + + gc.logger.Info("New connection.") + for { + buf, err := Read(c) + if err != nil { + gc.logger.Error("Error reading from connection.", "error", err) + return err + } + + cmd := strings.Split(string(buf), " ") + gc.logger.Debug("Got command.", "cmd", cmd) + var resp string + switch cmd[0] { + case cmdPing: + resp = "pong" + case cmdPID: + resp = strconv.Itoa(pid) + default: + h := gc.handlers[cmd[0]] + if h == nil { + resp = "ERROR: unknown command" + } else { + resp = h.Fn(cmd[1:]) + } + } + gc.logger.Debug("Responding.", "resp", resp) + Write(c, []byte(resp)) + } + /* NOTREACHED */ +} + +func (gc *Goctl) isAlreadyRunning() string { + c, err := net.Dial("unix", gc.path) + if err != nil { + return "" + } + defer c.Close() + + dataChan := make(chan []byte, 1) + c.Write([]byte(cmdPID)) + go func() { + buf := make([]byte, 1024) + n, err := c.Read(buf[:]) + if err == nil { + dataChan <- buf[0:n] + } + }() + + timeoutChan := make(chan bool, 1) + go func() { + time.Sleep(timeout) + timeoutChan <- true + }() + + select { + case buf := <-dataChan: + return string(buf) + case <-timeoutChan: + gc.logger.Info("Timed out checking PID of existing service.", "path", gc.path) + return "" + } +} diff --git a/goctl_test.go b/goctl_test.go new file mode 100644 index 0000000..a730500 --- /dev/null +++ b/goctl_test.go @@ -0,0 +1,178 @@ +package goctl + +import ( + "net" + "reflect" + "testing" +) + +const SOCKPATH = "/tmp/xmppbot" + +func TestPing(t *testing.T) { + gc := NewGoctl(SOCKPATH) + if err := gc.Start(); err != nil { + t.Fatalf("Couldn't start: %s.", err) + } + defer gc.Stop() + + c, err := net.Dial("unix", SOCKPATH) + if err != nil { + t.Fatalf("Couldn't open %s: %s.", SOCKPATH, err) + } + defer c.Close() + + buf := []byte("ping") + Write(c, buf) + + buf, err = Read(c) + if err != nil { + t.Fatalf("Couldn't read from socket: %s.", err) + } + s := string(buf) + if s != "pong" { + t.Errorf("Sending ping: got '%s', expected 'pong'.", s) + } +} + +func TestAlreadyRunning(t *testing.T) { + gc := NewGoctl(SOCKPATH) + if err := gc.Start(); err != nil { + t.Errorf("Couldn't start: %s.", err) + } + defer gc.Stop() + + bar := NewGoctl(SOCKPATH) + if bar.Start() == nil { + t.Errorf("Started bot when already running.") + } + defer bar.Stop() +} + +func TestSimulCommands(t *testing.T) { + gc := NewGoctl(SOCKPATH) + if err := gc.Start(); err != nil { + t.Fatalf("Couldn't start: %s.", err) + } + defer gc.Stop() + + c0, err := net.Dial("unix", SOCKPATH) + if err != nil { + t.Fatalf("Coudln't open first connection to %s: %s.", SOCKPATH, err) + } + defer c0.Close() + + c1, err := net.Dial("unix", SOCKPATH) + if err != nil { + t.Fatalf("Coudln't open second connection to %s: %s.", SOCKPATH, err) + } + defer c1.Close() + + Write(c0, []byte("ping")) + Write(c1, []byte("ping")) + + if buf, err := Read(c0); err != nil { + t.Errorf("Couldn't read from first connection: %s.", err) + } else if string(buf) != "pong" { + t.Fatalf("Sending ping on first connection: got '%s', expected 'pong'.", string(buf)) + } + if buf, err := Read(c1); err != nil { + t.Fatalf("Couldn't read from second connection: %s.", err) + } else if string(buf) != "pong" { + t.Fatalf("Sending ping on second connection: got '%s', expected 'pong'.", string(buf)) + } +} + +func TestAddHandler(t *testing.T) { + gc := NewGoctl(SOCKPATH) + if err := gc.Start(); err != nil { + t.Fatalf("Couldn't start: %s", err) + } + defer gc.Stop() + + gc.AddHandler("foo", func(args []string) string { + if !reflect.DeepEqual(args, []string{"bar", "baz"}) { + t.Errorf("Got %v, expected ['bar', 'baz']", args) + } + return "bar baz" + }) + + c, err := net.Dial("unix", SOCKPATH) + if err != nil { + t.Fatalf("Coudln't open connection to %s: %s.", SOCKPATH, err) + } + defer c.Close() + Write(c, []byte("foo bar baz")) + + if buf, err := Read(c); err != nil { + t.Errorf("Couldn't read from connection: %s.", err) + } else if string(buf) != "bar baz" { + t.Errorf("Got: %s, expected 'bar baz'") + } +} + +func TestAddHandlers(t *testing.T) { + gc := NewGoctl(SOCKPATH) + if err := gc.Start(); err != nil { + t.Fatalf("Couldn't start: %s", err) + } + defer gc.Stop() + + gc.AddHandlers([]*Handler{ + {"foo", func(args []string) string { + if !reflect.DeepEqual(args, []string{"bar", "baz"}) { + t.Errorf("Got %v, expected ['bar', 'baz']", args) + } + return "" + }}, + {"bar", func(args []string) string { + if !reflect.DeepEqual(args, []string{"baz", "pham"}) { + t.Errorf("Got %v, expected ['baz', 'pham']", args) + } + return "wauug" + }}}) + + c, err := net.Dial("unix", SOCKPATH) + if err != nil { + t.Fatalf("Coudln't open connection to %s: %s.", SOCKPATH, err) + } + defer c.Close() + Write(c, []byte("foo bar baz")) + Write(c, []byte("bar baz pham")) + + if buf, err := Read(c); err != nil { + t.Errorf("Couldn't read from connection: %s.", err) + } else if string(buf) != "" { + t.Errorf("Got: %s, expected ''", string(buf)) + } + + if buf, err := Read(c); err != nil { + t.Errorf("Couldn't read from connection: %s.", err) + } else if string(buf) != "wauug" { + t.Errorf("Got: %s, expected 'wauug'", string(buf)) + } +} + +func BenchmarkStartStop(b *testing.B) { + gc := NewGoctl(SOCKPATH) + for i := 0; i < b.N; i++ { + gc.Start() + gc.Stop() + } +} + +func BenchmarkPing(b *testing.B) { + gc := NewGoctl(SOCKPATH) + gc.Start() + defer gc.Stop() + + c, err := net.Dial("unix", SOCKPATH) + if err != nil { + b.Fatalf("Coudln't open connection to %s: %s.", SOCKPATH, err) + } + defer c.Close() + + for i := 0; i < b.N; i++ { + Write(c, []byte("ping")) + Read(c) + } +} |