]> Cypherpunks repositories - gostls13.git/commitdiff
log/syslog: retry once if write fails
authorJeff R. Allen <jra@nella.org>
Tue, 5 Feb 2013 17:54:01 +0000 (09:54 -0800)
committerBrad Fitzpatrick <bradfitz@golang.org>
Tue, 5 Feb 2013 17:54:01 +0000 (09:54 -0800)
Implements deferred connections + single-attempt automatic
retry. Based on CL 5078042 from kuroneko.

Fixes #2264.

R=mikioh.mikioh, rsc, bradfitz
CC=golang-dev
https://golang.org/cl/6782140

src/pkg/log/syslog/syslog.go
src/pkg/log/syslog/syslog_test.go
src/pkg/log/syslog/syslog_unix.go

index 98b9c5f6e8773d799febe617f8584fc642803618..8bdd9825e1cdb045800047c12a9cc5df1d4e2d65 100644 (file)
@@ -6,7 +6,11 @@
 
 // Package syslog provides a simple interface to the system log
 // service. It can send messages to the syslog daemon using UNIX
-// domain sockets, UDP, or TCP connections.
+// domain sockets, UDP or TCP.
+//
+// Only one call to Dial is necessary. On write failures,
+// the syslog client will attempt to reconnect to the server
+// and write again.
 package syslog
 
 import (
@@ -15,6 +19,8 @@ import (
        "log"
        "net"
        "os"
+       "strings"
+       "sync"
        "time"
 )
 
@@ -78,15 +84,10 @@ type Writer struct {
        priority Priority
        tag      string
        hostname string
-       conn     serverConn
-}
-
-type serverConn interface {
-       writeString(p Priority, hostname, tag, s string) (int, error)
-       close() error
-}
+       network  string
+       raddr    string
 
-type netConn struct {
+       mu   sync.Mutex // guards conn
        conn net.Conn
 }
 
@@ -101,7 +102,7 @@ func New(priority Priority, tag string) (w *Writer, err error) {
 // address raddr on the network net.  Each write to the returned
 // writer sends a log message with the given facility, severity and
 // tag.
-func Dial(network, raddr string, priority Priority, tag string) (w *Writer, err error) {
+func Dial(network, raddr string, priority Priority, tag string) (*Writer, error) {
        if priority < 0 || priority > LOG_LOCAL7|LOG_DEBUG {
                return nil, errors.New("log/syslog: invalid priority")
        }
@@ -109,117 +110,160 @@ func Dial(network, raddr string, priority Priority, tag string) (w *Writer, err
        if tag == "" {
                tag = os.Args[0]
        }
-
        hostname, _ := os.Hostname()
 
-       var conn serverConn
-       if network == "" {
-               conn, err = unixSyslog()
-               if hostname == "" {
-                       hostname = "localhost"
-               }
-       } else {
-               var c net.Conn
-               c, err = net.Dial(network, raddr)
-               conn = netConn{c}
-               if hostname == "" {
-                       hostname = c.LocalAddr().String()
-               }
+       w := &Writer{
+               priority: priority,
+               tag:      tag,
+               hostname: hostname,
+               network:  network,
+               raddr:    raddr,
        }
+
+       w.mu.Lock()
+       defer w.mu.Unlock()
+
+       err := w.connect()
        if err != nil {
                return nil, err
        }
+       return w, err
+}
+
+// connect makes a connection to the syslog server.
+// It must be called with w.mu held.
+func (w *Writer) connect() (err error) {
+       if w.conn != nil {
+               // ignore err from close, it makes sense to continue anyway
+               w.conn.Close()
+               w.conn = nil
+       }
 
-       return &Writer{priority: priority, tag: tag, hostname: hostname, conn: conn}, nil
+       if w.network == "" {
+               w.conn, err = unixSyslog()
+               if w.hostname == "" {
+                       w.hostname = "localhost"
+               }
+       } else {
+               var c net.Conn
+               c, err = net.Dial(w.network, w.raddr)
+               if err == nil {
+                       w.conn = c
+                       if w.hostname == "" {
+                               w.hostname = c.LocalAddr().String()
+                       }
+               }
+       }
+       return
 }
 
 // Write sends a log message to the syslog daemon.
 func (w *Writer) Write(b []byte) (int, error) {
-       return w.writeString(w.priority, string(b))
+       return w.writeAndRetry(w.priority, string(b))
 }
 
-func (w *Writer) Close() error { return w.conn.close() }
+// Close closes a connection to the syslog daemon.
+func (w *Writer) Close() error {
+       w.mu.Lock()
+       defer w.mu.Unlock()
+
+       if w.conn != nil {
+               err := w.conn.Close()
+               w.conn = nil
+               return err
+       }
+       return nil
+}
 
 // Emerg logs a message with severity LOG_EMERG, ignoring the severity
 // passed to New.
 func (w *Writer) Emerg(m string) (err error) {
-       _, err = w.writeString(LOG_EMERG, m)
+       _, err = w.writeAndRetry(LOG_EMERG, m)
        return err
 }
 
 // Alert logs a message with severity LOG_ALERT, ignoring the severity
 // passed to New.
 func (w *Writer) Alert(m string) (err error) {
-       _, err = w.writeString(LOG_ALERT, m)
+       _, err = w.writeAndRetry(LOG_ALERT, m)
        return err
 }
 
 // Crit logs a message with severity LOG_CRIT, ignoring the severity
 // passed to New.
 func (w *Writer) Crit(m string) (err error) {
-       _, err = w.writeString(LOG_CRIT, m)
+       _, err = w.writeAndRetry(LOG_CRIT, m)
        return err
 }
 
 // Err logs a message with severity LOG_ERR, ignoring the severity
 // passed to New.
 func (w *Writer) Err(m string) (err error) {
-       _, err = w.writeString(LOG_ERR, m)
+       _, err = w.writeAndRetry(LOG_ERR, m)
        return err
 }
 
 // Wanring logs a message with severity LOG_WARNING, ignoring the
 // severity passed to New.
 func (w *Writer) Warning(m string) (err error) {
-       _, err = w.writeString(LOG_WARNING, m)
+       _, err = w.writeAndRetry(LOG_WARNING, m)
        return err
 }
 
 // Notice logs a message with severity LOG_NOTICE, ignoring the
 // severity passed to New.
 func (w *Writer) Notice(m string) (err error) {
-       _, err = w.writeString(LOG_NOTICE, m)
+       _, err = w.writeAndRetry(LOG_NOTICE, m)
        return err
 }
 
 // Info logs a message with severity LOG_INFO, ignoring the severity
 // passed to New.
 func (w *Writer) Info(m string) (err error) {
-       _, err = w.writeString(LOG_INFO, m)
+       _, err = w.writeAndRetry(LOG_INFO, m)
        return err
 }
 
 // Debug logs a message with severity LOG_DEBUG, ignoring the severity
 // passed to New.
 func (w *Writer) Debug(m string) (err error) {
-       _, err = w.writeString(LOG_DEBUG, m)
+       _, err = w.writeAndRetry(LOG_DEBUG, m)
        return err
 }
 
-func (w *Writer) writeString(p Priority, s string) (int, error) {
-       return w.conn.writeString((w.priority&facilityMask)|(p&severityMask),
-               w.hostname, w.tag, s)
+func (w *Writer) writeAndRetry(p Priority, s string) (int, error) {
+       pr := (w.priority & facilityMask) | (p & severityMask)
+
+       w.mu.Lock()
+       defer w.mu.Unlock()
+
+       if w.conn != nil {
+               if n, err := w.write(pr, s); err == nil {
+                       return n, err
+               }
+       }
+       if err := w.connect(); err != nil {
+               return 0, err
+       }
+       return w.write(pr, s)
 }
 
-// writeString: generates and writes a syslog formatted string. The
+// write generates and writes a syslog formatted string. The
 // format is as follows: <PRI>TIMESTAMP HOSTNAME TAG[PID]: MSG
-func (n netConn) writeString(p Priority, hostname, tag, msg string) (int, error) {
+func (w *Writer) write(p Priority, msg string) (int, error) {
+       // ensure it ends in a \n
        nl := ""
-       if len(msg) == 0 || msg[len(msg)-1] != '\n' {
+       if !strings.HasSuffix(msg, "\n") {
                nl = "\n"
        }
+
        timestamp := time.Now().Format(time.RFC3339)
-       if _, err := fmt.Fprintf(n.conn, "<%d>%s %s %s[%d]: %s%s", p, timestamp, hostname,
-               tag, os.Getpid(), msg, nl); err != nil {
-               return 0, err
-       }
+       fmt.Fprintf(w.conn, "<%d>%s %s %s[%d]: %s%s",
+               p, timestamp, w.hostname,
+               w.tag, os.Getpid(), msg, nl)
        return len(msg), nil
 }
 
-func (n netConn) close() error {
-       return n.conn.Close()
-}
-
 // NewLogger creates a log.Logger whose output is written to
 // the system log service with the specified priority. The logFlag
 // argument is the flag set passed through to log.New to create
index 3770b34d41c89dea4b1a7674d54cf3a63e1d8c96..51fbde2ae967c3133b8375e226bacd8843a18d50 100644 (file)
 package syslog
 
 import (
+       "bufio"
        "fmt"
        "io"
+       "io/ioutil"
        "log"
        "net"
        "os"
+       "sync"
        "testing"
        "time"
 )
 
-var serverAddr string
-
-func runSyslog(c net.PacketConn, done chan<- string) {
+func runPktSyslog(c net.PacketConn, done chan<- string) {
        var buf [4096]byte
        var rcvd string
+       ct := 0
        for {
+               var n int
+               var err error
+
                c.SetReadDeadline(time.Now().Add(100 * time.Millisecond))
-               n, _, err := c.ReadFrom(buf[:])
+               n, _, err = c.ReadFrom(buf[:])
                rcvd += string(buf[:n])
                if err != nil {
+                       if oe, ok := err.(*net.OpError); ok {
+                               if ct < 3 && oe.Temporary() {
+                                       ct++
+                                       continue
+                               }
+                       }
                        break
                }
        }
+       c.Close()
        done <- rcvd
 }
 
-func startServer(done chan<- string) {
-       c, e := net.ListenPacket("udp", "127.0.0.1:0")
-       if e != nil {
-               log.Fatalf("net.ListenPacket failed udp :0 %v", e)
+var crashy = false
+
+func runStreamSyslog(l net.Listener, done chan<- string) {
+       for {
+               var c net.Conn
+               var err error
+               if c, err = l.Accept(); err != nil {
+                       fmt.Print(err)
+                       return
+               }
+               go func(c net.Conn) {
+                       b := bufio.NewReader(c)
+                       for ct := 1; !crashy || ct&7 != 0; ct++ {
+                               s, err := b.ReadString('\n')
+                               if err != nil {
+                                       break
+                               }
+                               done <- s
+                       }
+                       c.Close()
+               }(c)
+       }
+}
+
+func startServer(n, la string, done chan<- string) (addr string) {
+       if n == "udp" || n == "tcp" {
+               la = "127.0.0.1:0"
+       } else {
+               // unix and unixgram: choose an address if none given
+               if la == "" {
+                       // use ioutil.TempFile to get a name that is unique
+                       f, err := ioutil.TempFile("", "syslogtest")
+                       if err != nil {
+                               log.Fatal("TempFile: ", err)
+                       }
+                       f.Close()
+                       la = f.Name()
+               }
+               os.Remove(la)
+       }
+
+       if n == "udp" || n == "unixgram" {
+               l, e := net.ListenPacket(n, la)
+               if e != nil {
+                       log.Fatalf("startServer failed: %v", e)
+               }
+               addr = l.LocalAddr().String()
+               go runPktSyslog(l, done)
+       } else {
+               l, e := net.Listen(n, la)
+               if e != nil {
+                       log.Fatalf("startServer failed: %v", e)
+               }
+               addr = l.Addr().String()
+               go runStreamSyslog(l, done)
+       }
+       return
+}
+
+func TestWithSimulated(t *testing.T) {
+       msg := "Test 123"
+       transport := []string{"unix", "unixgram", "udp", "tcp"}
+
+       for _, tr := range transport {
+               done := make(chan string)
+               addr := startServer(tr, "", done)
+               s, err := Dial(tr, addr, LOG_INFO|LOG_USER, "syslog_test")
+               if err != nil {
+                       t.Fatalf("Dial() failed: %v", err)
+               }
+               err = s.Info(msg)
+               if err != nil {
+                       t.Fatalf("log failed: %v", err)
+               }
+               check(t, msg, <-done)
+               s.Close()
+       }
+}
+
+func TestFlap(t *testing.T) {
+       net := "unix"
+       done := make(chan string)
+       addr := startServer(net, "", done)
+
+       s, err := Dial(net, addr, LOG_INFO|LOG_USER, "syslog_test")
+       if err != nil {
+               t.Fatalf("Dial() failed: %v", err)
+       }
+       msg := "Moo 2"
+       err = s.Info(msg)
+       if err != nil {
+               t.Fatalf("log failed: %v", err)
+       }
+       check(t, msg, <-done)
+
+       // restart the server
+       startServer(net, addr, done)
+
+       // and try retransmitting
+       msg = "Moo 3"
+       err = s.Info(msg)
+       if err != nil {
+               t.Fatalf("log failed: %v", err)
        }
-       serverAddr = c.LocalAddr().String()
-       go runSyslog(c, done)
+       check(t, msg, <-done)
+
+       s.Close()
 }
 
 func TestNew(t *testing.T) {
@@ -49,7 +161,8 @@ func TestNew(t *testing.T) {
                // Depends on syslog daemon running, and sometimes it's not.
                t.Skip("skipping syslog test during -short")
        }
-       s, err := New(LOG_INFO|LOG_USER, "")
+
+       s, err := New(LOG_INFO|LOG_USER, "the_tag")
        if err != nil {
                t.Fatalf("New() failed: %s", err)
        }
@@ -86,24 +199,15 @@ func TestDial(t *testing.T) {
        l.Close()
 }
 
-func TestUDPDial(t *testing.T) {
-       done := make(chan string)
-       startServer(done)
-       l, err := Dial("udp", serverAddr, LOG_USER|LOG_INFO, "syslog_test")
-       if err != nil {
-               t.Fatalf("syslog.Dial() failed: %s", err)
-       }
-       msg := "udp test"
-       l.Info(msg)
-       expected := fmt.Sprintf("<%d>", LOG_USER+LOG_INFO) + "%s %s syslog_test[%d]: udp test\n"
-       rcvd := <-done
-       var parsedHostname, timestamp string
-       var pid int
+func check(t *testing.T, in, out string) {
+       tmpl := fmt.Sprintf("<%d>%%s %%s syslog_test[%%d]: %s\n", LOG_USER+LOG_INFO, in)
        if hostname, err := os.Hostname(); err != nil {
-               t.Fatalf("Error retrieving hostname")
+               t.Error("Error retrieving hostname")
        } else {
-               if n, err := fmt.Sscanf(rcvd, expected, &timestamp, &parsedHostname, &pid); n != 3 || err != nil || hostname != parsedHostname {
-                       t.Fatalf("'%q', didn't match '%q' (%d, %s)", rcvd, expected, n, err)
+               var parsedHostname, timestamp string
+               var pid int
+               if n, err := fmt.Sscanf(out, tmpl, &timestamp, &parsedHostname, &pid); n != 3 || err != nil || hostname != parsedHostname {
+                       t.Errorf("Got %q, does not match template %q (%d %s)", out, tmpl, n, err)
                }
        }
 }
@@ -126,24 +230,95 @@ func TestWrite(t *testing.T) {
        } else {
                for _, test := range tests {
                        done := make(chan string)
-                       startServer(done)
-                       l, err := Dial("udp", serverAddr, test.pri, test.pre)
+                       addr := startServer("udp", "", done)
+                       l, err := Dial("udp", addr, test.pri, test.pre)
                        if err != nil {
-                               t.Fatalf("syslog.Dial() failed: %s", err)
+                               t.Fatalf("syslog.Dial() failed: %v", err)
                        }
                        _, err = io.WriteString(l, test.msg)
                        if err != nil {
-                               t.Fatalf("WriteString() failed: %s", err)
+                               t.Fatalf("WriteString() failed: %v", err)
                        }
                        rcvd := <-done
                        test.exp = fmt.Sprintf("<%d>", test.pri) + test.exp
                        var parsedHostname, timestamp string
                        var pid int
-                       if n, err := fmt.Sscanf(rcvd, test.exp, &timestamp, &parsedHostname,
-                               &pid); n != 3 || err != nil || hostname != parsedHostname {
-                               t.Fatalf("'%q', didn't match '%q' (%d %s)", rcvd, test.exp,
-                                       n, err)
+                       if n, err := fmt.Sscanf(rcvd, test.exp, &timestamp, &parsedHostname, &pid); n != 3 || err != nil || hostname != parsedHostname {
+                               t.Errorf("s.Info() = '%q', didn't match '%q' (%d %s)", rcvd, test.exp, n, err)
                        }
                }
        }
 }
+
+func TestConcurrentWrite(t *testing.T) {
+       addr := startServer("udp", "", make(chan string))
+       w, err := Dial("udp", addr, LOG_USER|LOG_ERR, "how's it going?")
+       if err != nil {
+               t.Fatalf("syslog.Dial() failed: %v", err)
+       }
+       var wg sync.WaitGroup
+       for i := 0; i < 10; i++ {
+               wg.Add(1)
+               go func() {
+                       err := w.Info("test")
+                       if err != nil {
+                               t.Errorf("Info() failed: %v", err)
+                               return
+                       }
+                       wg.Done()
+               }()
+       }
+       wg.Wait()
+}
+
+func TestConcurrentReconnect(t *testing.T) {
+       crashy = true
+       defer func() { crashy = false }()
+
+       net := "unix"
+       done := make(chan string)
+       addr := startServer(net, "", done)
+
+       // count all the messages arriving
+       count := make(chan int)
+       go func() {
+               ct := 0
+               for _ = range done {
+                       ct++
+                       // we are looking for 500 out of 1000 events
+                       // here because lots of log messages are lost
+                       // in buffers (kernel and/or bufio)
+                       if ct > 500 {
+                               break
+                       }
+               }
+               count <- ct
+       }()
+
+       var wg sync.WaitGroup
+       for i := 0; i < 10; i++ {
+               wg.Add(1)
+               go func() {
+                       w, err := Dial(net, addr, LOG_USER|LOG_ERR, "tag")
+                       if err != nil {
+                               t.Fatalf("syslog.Dial() failed: %v", err)
+                       }
+                       for i := 0; i < 100; i++ {
+                               err := w.Info("test")
+                               if err != nil {
+                                       t.Errorf("Info() failed: %v", err)
+                                       return
+                               }
+                       }
+                       wg.Done()
+               }()
+       }
+       wg.Wait()
+       close(done)
+
+       select {
+       case <-count:
+       case <-time.After(100 * time.Millisecond):
+               t.Error("timeout in concurrent reconnect")
+       }
+}
index 46a164dd5773b07b4cb53c7a9d325a8f78966e63..a0001ccaea9bfc7be63f202c05f4b3caaf7bc279 100644 (file)
@@ -14,18 +14,16 @@ import (
 // unixSyslog opens a connection to the syslog daemon running on the
 // local machine using a Unix domain socket.
 
-func unixSyslog() (conn serverConn, err error) {
+func unixSyslog() (conn net.Conn, err error) {
        logTypes := []string{"unixgram", "unix"}
        logPaths := []string{"/dev/log", "/var/run/syslog"}
-       var raddr string
        for _, network := range logTypes {
                for _, path := range logPaths {
-                       raddr = path
-                       conn, err := net.Dial(network, raddr)
+                       conn, err := net.Dial(network, path)
                        if err != nil {
                                continue
                        } else {
-                               return netConn{conn}, nil
+                               return conn, nil
                        }
                }
        }