summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--.gitignore1
-rw-r--r--cmd/goctl/main.go35
-rw-r--r--goctl.go202
-rw-r--r--goctl_test.go178
4 files changed, 416 insertions, 0 deletions
diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..7e508bc
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1 @@
+goctl
diff --git a/cmd/goctl/main.go b/cmd/goctl/main.go
new file mode 100644
index 0000000..1537c9a
--- /dev/null
+++ b/cmd/goctl/main.go
@@ -0,0 +1,35 @@
+package main
+
+import (
+ "flag"
+ "fmt"
+ "log"
+ "net"
+ "os"
+ "strings"
+
+ "github.com/bjc/goctl"
+)
+
+func main() {
+ var path = flag.String("path", "", "Socket path for sending commands (required).")
+ flag.Parse()
+
+ if *path == "" {
+ flag.Usage()
+ os.Exit(1)
+ }
+
+ c, err := net.Dial("unix", *path)
+ if err != nil {
+ log.Fatalf("Couldn't connect to %s: %s.", *path, err)
+ }
+ defer c.Close()
+
+ goctl.Write(c, []byte(strings.Join(flag.Args(), " ")))
+ if buf, err := goctl.Read(c); err != nil {
+ log.Fatalf("Error reading response from command: %s.", err)
+ } else {
+ fmt.Println(string(buf))
+ }
+}
diff --git a/goctl.go b/goctl.go
new file mode 100644
index 0000000..85b367f
--- /dev/null
+++ b/goctl.go
@@ -0,0 +1,202 @@
+package goctl
+
+import (
+ "encoding/binary"
+ "fmt"
+ "io"
+ "net"
+ "os"
+ "strconv"
+ "strings"
+ "sync/atomic"
+ "time"
+
+ "gopkg.in/inconshreveable/log15.v2"
+)
+
+// How long to wait for a response before returning an error.
+const timeout = 100 * time.Millisecond
+
+// Built-in commands which can't be overridden.
+const (
+ cmdPing = "ping"
+ cmdPID = "pid"
+)
+
+var (
+ pid int
+ lastid uint64
+ Logger log15.Logger
+)
+
+type Goctl struct {
+ logger log15.Logger
+ path string
+ listener net.Listener
+ handlers map[string]*Handler
+}
+
+type Handler struct {
+ Name string
+ Fn func([]string) string
+}
+
+func init() {
+ pid = os.Getpid()
+ Logger = log15.New()
+ Logger.SetHandler(log15.DiscardHandler())
+}
+
+func NewGoctl(path string) Goctl {
+ return Goctl{
+ logger: Logger.New("id", atomic.AddUint64(&lastid, 1)),
+ path: path,
+ handlers: make(map[string]*Handler),
+ }
+}
+
+func (gc *Goctl) Start() error {
+ gc.logger.Info("Starting command listener.")
+
+ if pid := gc.isAlreadyRunning(); pid != "" {
+ gc.logger.Crit("Command listener already running.", "pid", pid)
+ return fmt.Errorf("already running on pid %s", pid)
+ } else {
+ os.Remove(gc.path)
+ }
+
+ var err error
+ gc.listener, err = net.Listen("unix", gc.path)
+ if err != nil {
+ gc.logger.Crit("Couldn't listen on socket.", "path", gc.path, "error", err)
+ return err
+ }
+
+ go gc.acceptor()
+ return nil
+}
+
+func (gc *Goctl) Stop() {
+ gc.logger.Info("Stopping command listener.")
+
+ if gc.listener != nil {
+ gc.listener.Close()
+ os.Remove(gc.path)
+ }
+}
+
+func (gc *Goctl) AddHandler(name string, fn func([]string) string) {
+ gc.AddHandlers([]*Handler{{name, fn}})
+}
+
+func (gc *Goctl) AddHandlers(handlers []*Handler) {
+ for _, h := range handlers {
+ gc.handlers[h.Name] = h
+ }
+}
+
+func Read(r io.Reader) ([]byte, error) {
+ var n uint16
+ if err := binary.Read(r, binary.BigEndian, &n); err != nil {
+ return nil, err
+ }
+
+ p := make([]byte, n)
+ if n == 0 {
+ // Don't attempt to read no data or EOF will be
+ // returned. Instead, just return an empty byte slice.
+ return p, nil
+ }
+
+ if _, err := r.Read(p[:]); err != nil {
+ return nil, err
+ }
+ return p, nil
+}
+
+func Write(w io.Writer, p []byte) error {
+ if err := binary.Write(w, binary.BigEndian, uint16(len(p))); err != nil {
+ return err
+ }
+ if _, err := w.Write(p); err != nil {
+ return err
+ }
+ return nil
+}
+
+func (gc *Goctl) acceptor() {
+ for {
+ c, err := gc.listener.Accept()
+ if err != nil {
+ gc.logger.Error("Error accepting connection.", "error", err)
+ return
+ }
+ go gc.reader(c)
+ }
+}
+
+func (gc *Goctl) reader(c io.ReadWriteCloser) error {
+ defer gc.logger.Info("Connection closed.")
+ defer c.Close()
+
+ gc.logger.Info("New connection.")
+ for {
+ buf, err := Read(c)
+ if err != nil {
+ gc.logger.Error("Error reading from connection.", "error", err)
+ return err
+ }
+
+ cmd := strings.Split(string(buf), " ")
+ gc.logger.Debug("Got command.", "cmd", cmd)
+ var resp string
+ switch cmd[0] {
+ case cmdPing:
+ resp = "pong"
+ case cmdPID:
+ resp = strconv.Itoa(pid)
+ default:
+ h := gc.handlers[cmd[0]]
+ if h == nil {
+ resp = "ERROR: unknown command"
+ } else {
+ resp = h.Fn(cmd[1:])
+ }
+ }
+ gc.logger.Debug("Responding.", "resp", resp)
+ Write(c, []byte(resp))
+ }
+ /* NOTREACHED */
+}
+
+func (gc *Goctl) isAlreadyRunning() string {
+ c, err := net.Dial("unix", gc.path)
+ if err != nil {
+ return ""
+ }
+ defer c.Close()
+
+ dataChan := make(chan []byte, 1)
+ c.Write([]byte(cmdPID))
+ go func() {
+ buf := make([]byte, 1024)
+ n, err := c.Read(buf[:])
+ if err == nil {
+ dataChan <- buf[0:n]
+ }
+ }()
+
+ timeoutChan := make(chan bool, 1)
+ go func() {
+ time.Sleep(timeout)
+ timeoutChan <- true
+ }()
+
+ select {
+ case buf := <-dataChan:
+ return string(buf)
+ case <-timeoutChan:
+ gc.logger.Info("Timed out checking PID of existing service.", "path", gc.path)
+ return ""
+ }
+}
diff --git a/goctl_test.go b/goctl_test.go
new file mode 100644
index 0000000..a730500
--- /dev/null
+++ b/goctl_test.go
@@ -0,0 +1,178 @@
+package goctl
+
+import (
+ "net"
+ "reflect"
+ "testing"
+)
+
+const SOCKPATH = "/tmp/xmppbot"
+
+func TestPing(t *testing.T) {
+ gc := NewGoctl(SOCKPATH)
+ if err := gc.Start(); err != nil {
+ t.Fatalf("Couldn't start: %s.", err)
+ }
+ defer gc.Stop()
+
+ c, err := net.Dial("unix", SOCKPATH)
+ if err != nil {
+ t.Fatalf("Couldn't open %s: %s.", SOCKPATH, err)
+ }
+ defer c.Close()
+
+ buf := []byte("ping")
+ Write(c, buf)
+
+ buf, err = Read(c)
+ if err != nil {
+ t.Fatalf("Couldn't read from socket: %s.", err)
+ }
+ s := string(buf)
+ if s != "pong" {
+ t.Errorf("Sending ping: got '%s', expected 'pong'.", s)
+ }
+}
+
+func TestAlreadyRunning(t *testing.T) {
+ gc := NewGoctl(SOCKPATH)
+ if err := gc.Start(); err != nil {
+ t.Errorf("Couldn't start: %s.", err)
+ }
+ defer gc.Stop()
+
+ bar := NewGoctl(SOCKPATH)
+ if bar.Start() == nil {
+ t.Errorf("Started bot when already running.")
+ }
+ defer bar.Stop()
+}
+
+func TestSimulCommands(t *testing.T) {
+ gc := NewGoctl(SOCKPATH)
+ if err := gc.Start(); err != nil {
+ t.Fatalf("Couldn't start: %s.", err)
+ }
+ defer gc.Stop()
+
+ c0, err := net.Dial("unix", SOCKPATH)
+ if err != nil {
+ t.Fatalf("Coudln't open first connection to %s: %s.", SOCKPATH, err)
+ }
+ defer c0.Close()
+
+ c1, err := net.Dial("unix", SOCKPATH)
+ if err != nil {
+ t.Fatalf("Coudln't open second connection to %s: %s.", SOCKPATH, err)
+ }
+ defer c1.Close()
+
+ Write(c0, []byte("ping"))
+ Write(c1, []byte("ping"))
+
+ if buf, err := Read(c0); err != nil {
+ t.Errorf("Couldn't read from first connection: %s.", err)
+ } else if string(buf) != "pong" {
+ t.Fatalf("Sending ping on first connection: got '%s', expected 'pong'.", string(buf))
+ }
+ if buf, err := Read(c1); err != nil {
+ t.Fatalf("Couldn't read from second connection: %s.", err)
+ } else if string(buf) != "pong" {
+ t.Fatalf("Sending ping on second connection: got '%s', expected 'pong'.", string(buf))
+ }
+}
+
+func TestAddHandler(t *testing.T) {
+ gc := NewGoctl(SOCKPATH)
+ if err := gc.Start(); err != nil {
+ t.Fatalf("Couldn't start: %s", err)
+ }
+ defer gc.Stop()
+
+ gc.AddHandler("foo", func(args []string) string {
+ if !reflect.DeepEqual(args, []string{"bar", "baz"}) {
+ t.Errorf("Got %v, expected ['bar', 'baz']", args)
+ }
+ return "bar baz"
+ })
+
+ c, err := net.Dial("unix", SOCKPATH)
+ if err != nil {
+ t.Fatalf("Coudln't open connection to %s: %s.", SOCKPATH, err)
+ }
+ defer c.Close()
+ Write(c, []byte("foo bar baz"))
+
+ if buf, err := Read(c); err != nil {
+ t.Errorf("Couldn't read from connection: %s.", err)
+ } else if string(buf) != "bar baz" {
+ t.Errorf("Got: %s, expected 'bar baz'")
+ }
+}
+
+func TestAddHandlers(t *testing.T) {
+ gc := NewGoctl(SOCKPATH)
+ if err := gc.Start(); err != nil {
+ t.Fatalf("Couldn't start: %s", err)
+ }
+ defer gc.Stop()
+
+ gc.AddHandlers([]*Handler{
+ {"foo", func(args []string) string {
+ if !reflect.DeepEqual(args, []string{"bar", "baz"}) {
+ t.Errorf("Got %v, expected ['bar', 'baz']", args)
+ }
+ return ""
+ }},
+ {"bar", func(args []string) string {
+ if !reflect.DeepEqual(args, []string{"baz", "pham"}) {
+ t.Errorf("Got %v, expected ['baz', 'pham']", args)
+ }
+ return "wauug"
+ }}})
+
+ c, err := net.Dial("unix", SOCKPATH)
+ if err != nil {
+ t.Fatalf("Coudln't open connection to %s: %s.", SOCKPATH, err)
+ }
+ defer c.Close()
+ Write(c, []byte("foo bar baz"))
+ Write(c, []byte("bar baz pham"))
+
+ if buf, err := Read(c); err != nil {
+ t.Errorf("Couldn't read from connection: %s.", err)
+ } else if string(buf) != "" {
+ t.Errorf("Got: %s, expected ''", string(buf))
+ }
+
+ if buf, err := Read(c); err != nil {
+ t.Errorf("Couldn't read from connection: %s.", err)
+ } else if string(buf) != "wauug" {
+ t.Errorf("Got: %s, expected 'wauug'", string(buf))
+ }
+}
+
+func BenchmarkStartStop(b *testing.B) {
+ gc := NewGoctl(SOCKPATH)
+ for i := 0; i < b.N; i++ {
+ gc.Start()
+ gc.Stop()
+ }
+}
+
+func BenchmarkPing(b *testing.B) {
+ gc := NewGoctl(SOCKPATH)
+ gc.Start()
+ defer gc.Stop()
+
+ c, err := net.Dial("unix", SOCKPATH)
+ if err != nil {
+ b.Fatalf("Coudln't open connection to %s: %s.", SOCKPATH, err)
+ }
+ defer c.Close()
+
+ for i := 0; i < b.N; i++ {
+ Write(c, []byte("ping"))
+ Read(c)
+ }
+}