Newer
Older
package logger
import (
"bytes"
"io"
"path/filepath"
"github.com/gofiber/fiber/v2"
log "github.com/sirupsen/logrus"
"github.com/sirupsen/logrus/hooks/writer"
"github.com/oidc-mytoken/server/internal/config"
)
*log.Entry
rootHook *rootHook
ctx smartLoggerContext
}
// smartLoggerContext is the state shared between the hooks of one smart logger.
type smartLoggerContext struct {
	// buffer collects formatted log output while no error has been logged yet.
	buffer *bytes.Buffer
	// id is the request/session id; it is also used as the log file name.
	id string
}
// rootHook is the hook that smartPrepareLogger adds to every smart logger. It
// decides per entry whether the entry is buffered and/or handed to the errorHook.
type rootHook struct {
	// buffer receives entries as long as no error has been flushed yet.
	buffer log.Hook
	// error flushes the buffered output on the first error and afterwards
	// writes every entry directly.
	error *errorHook
}
// Levels registers the root hook for entries of every level; filtering by
// severity is done inside Fire.
func (*rootHook) Levels() []log.Level {
	every := log.AllLevels
	return every
}
// Fire dispatches e to the sub-hooks.
//
// While the errorHook has not fired yet, e is appended to the buffer. The errorHook
// is invoked for entries of level Error or more severe (Error, Fatal, Panic), and
// for every entry once it has fired before.
func (h *rootHook) Fire(e *log.Entry) error {
	errH := h.error
	if !errH.firedBefore {
		if bufErr := h.buffer.Fire(e); bufErr != nil {
			return bufErr
		}
	}
	// logrus orders levels from most severe (Panic = 0) to least severe.
	severe := e.Level <= log.ErrorLevel
	if !errH.firedBefore && !severe {
		return nil
	}
	return errH.Fire(e)
}
// errorHook writes a smart logger's output to a per-id log file once an error
// has been logged.
type errorHook struct {
	// Embedded shared state: the buffer to flush and the id naming the file.
	*smartLoggerContext
	// firedBefore is set once the first error triggered the flush to file;
	// afterwards entries are written to the file directly.
	firedBefore bool
	// file is the lazily opened log file; nil until first needed.
	file io.Writer
}
// Levels reports that the hook applies to all levels: after the first error has
// been seen, every subsequent entry, regardless of its level, goes to the file.
func (*errorHook) Levels() []log.Level {
	every := log.AllLevels
	return every
}
func (h *errorHook) Fire(e *log.Entry) (err error) {
var logData []byte
if h.firedBefore {
logData, err = e.Bytes()
if err != nil {
return
}
} else {
logData = h.smartLoggerContext.buffer.Bytes()
// from now on we will log all future log messages directly to file (if there are any)
h.firedBefore = true
h.smartLoggerContext.buffer.Reset()
file, errr := h.getFile()
if errr != nil {
return errr
if _, err = file.Write(logData); err != nil {
return
}
return
// getFile returns the writer for this logger's log file, opening it on first use.
// The file lives at <Logging.Internal.Smart.Dir>/<id>.
//
// Only a successfully opened writer is cached in h.file. Assigning the result on
// error could store a typed-nil pointer in the io.Writer field, which would be
// non-nil and make later calls return an unusable writer with a nil error. With
// this version a failed open is retried on the next call.
//
// NOTE(review): the opened file is not closed anywhere in the visible code of
// this file — confirm whether the package-level getFile or a caller closes it.
func (h *errorHook) getFile() (io.Writer, error) {
	if h.file != nil {
		return h.file, nil
	}
	file, err := getFile(filepath.Join(config.Get().Logging.Internal.Smart.Dir, h.smartLoggerContext.id))
	if err != nil {
		return nil, err
	}
	h.file = file
	return h.file, nil
}
// newErrorHook creates an errorHook that shares ctx with the other hooks of the
// same logger. The hook starts in buffering mode (firedBefore is false) and opens
// its file lazily.
func newErrorHook(ctx *smartLoggerContext) *errorHook {
	hook := new(errorHook)
	hook.smartLoggerContext = ctx
	return hook
}
// newBufferHook returns a logrus writer hook that appends entries of every level
// to the shared buffer in ctx.
func newBufferHook(ctx *smartLoggerContext) log.Hook {
	hook := &writer.Hook{LogLevels: log.AllLevels}
	hook.Writer = ctx.buffer
	return hook
}
func newRootHook(ctx *smartLoggerContext) *rootHook {
error: newErrorHook(ctx),
}
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
}
// smartPrepareLogger builds a new logrus.Logger that mirrors the standard logger's
// output, formatter, caller reporting, level and exit function. It also takes a
// copy of the standard logger's hooks and additionally registers rootH.
//
// Each per-level hook slice is copied, so adding rootH never modifies the
// standard logger's own hook table.
func smartPrepareLogger(rootH *rootHook) *log.Logger {
	base := log.StandardLogger()

	hooks := make(log.LevelHooks, len(base.Hooks))
	for level, levelHooks := range base.Hooks {
		hooks[level] = append([]log.Hook{}, levelHooks...)
	}
	hooks.Add(rootH)

	return &log.Logger{
		Out:          base.Out,
		Hooks:        hooks,
		Formatter:    base.Formatter,
		ReportCaller: base.ReportCaller,
		Level:        base.Level,
		ExitFunc:     base.ExitFunc,
	}
}
// getLogEntry returns an entry of logger that carries id in the "requestid" field.
func getLogEntry(id string, logger *log.Logger) *log.Entry {
	const requestIDField = "requestid"
	entry := logger.WithField(requestIDField, id)
	return entry
}
func getIDlogger(id string) log.Ext1FieldLogger {
if !config.Get().Logging.Internal.Smart.Enabled {
return getLogEntry(id, log.StandardLogger())
}
ctx: smartLoggerContext{
buffer: new(bytes.Buffer),
id: id,
},
}
smartLog.rootHook = newRootHook(&smartLog.ctx)
logger := smartPrepareLogger(smartLog.rootHook)
smartLog.Entry = getLogEntry(id, logger)
return smartLog
}
// GetRequestLogger returns a logrus.Ext1FieldLogger that always includes a request's id.
//
// The id is read from the fiber context local "requestid". If that local is unset,
// or holds a value that is not a string, the logger is created with an empty id.
func GetRequestLogger(ctx *fiber.Ctx) log.Ext1FieldLogger {
	// Checked assertion: an unchecked rid.(string) would panic in the logging path
	// if anything stored a non-string value under this key. A nil local also
	// yields "" here, matching the previous behavior for the unset case.
	rid, _ := ctx.Locals("requestid").(string)
	return getIDlogger(rid)
}
// GetSSHRequestLogger returns a logrus.Ext1FieldLogger that always includes an ssh request's id
func GetSSHRequestLogger(sessionID string) log.Ext1FieldLogger {
return getIDlogger(sessionID)