// Package cluster manages a set of MySQL database nodes: it hands out an
// active node for each transaction and keeps trying to reconnect nodes that
// are down.
package cluster

import (
	"fmt"
	"strings"
	"time"

	"github.com/jmoiron/sqlx"
	"github.com/pkg/errors"
	log "github.com/sirupsen/logrus"

	"github.com/oidc-mytoken/server/internal/config"
	"github.com/oidc-mytoken/server/internal/utils/errorfmt"

	// mysql driver
	_ "github.com/go-sql-driver/mysql"
)

// NewFromConfig creates a new Cluster from the passed config.DBConf. It
// starts the background re-connector and then tries to connect every host
// listed in the config.
func NewFromConfig(conf config.DBConf) *Cluster {
	c := newCluster(len(conf.Hosts))
	c.conf = &conf
	c.startReconnector()
	c.AddNodes()
	log.Debug("Created db cluster")
	return c
}

// newCluster allocates a Cluster whose node channels are buffered for size
// nodes. It neither connects any node nor starts the re-connector.
func newCluster(size int) *Cluster {
	return &Cluster{
		active: make(chan *node, size),
		down:   make(chan *node, size),
		stop:   make(chan interface{}),
	}
}

// Cluster is a type for holding a db cluster
type Cluster struct {
	active chan *node       // nodes that were connected successfully
	down   chan *node       // nodes waiting for a reconnection attempt
	stop   chan interface{} // receives a value when the cluster is closed
	conf   *config.DBConf
}

// node is a single database host together with its connection handle.
type node struct {
	db     *sqlx.DB
	host   string
	active bool
	// lock   sync.RWMutex
}

// close closes the node's database handle, if there is one, and clears it.
// A failure to close is logged; it is not returned.
func (n *node) close() {
	if n.db == nil {
		return
	}
	// Only report real failures; a successful Close returns a nil error.
	if err := n.db.Close(); err != nil {
		log.Errorf("%s", errorfmt.Full(err))
	}
	n.db = nil
}

// AddNodes adds the nodes specified for this Cluster to the cluster. Hosts
// that cannot be connected are logged and do not stop the remaining hosts from
// being added.
func (c *Cluster) AddNodes() {
	for _, host := range c.conf.Hosts {
		if err := c.AddNode(host); err != nil {
			log.Errorf("%s", errorfmt.Full(err))
		}
	}
}

// AddNode adds the passed host as a db node to the cluster. It returns the
// connection error if the host could not be connected.
func (c *Cluster) AddNode(host string) error {
	log.WithField("host", host).Debug("Adding node to db cluster")
	return c.addNode(
		&node{
			host: host,
		},
	)
}

// addNode (re)connects the passed node. Any existing handle is closed first.
// On success the node is marked active and queued on c.active; on failure it
// is marked inactive, queued on c.down and the connection error is returned.
func (c *Cluster) addNode(n *node) error {
	n.close()
	dsn := fmt.Sprintf("%s:%s@%s(%s)/%s?parseTime=true", c.conf.User, c.conf.GetPassword(), "tcp", n.host, c.conf.DB)
	db, err := connectDSN(dsn)
	if err != nil {
		n.active = false
		c.down <- n
		// The DSN embeds the database password, so log the host rather
		// than the DSN itself.
		log.WithField("host", n.host).Debug("Could not connect node")
		return err
	}
	n.db = db
	n.active = true
	c.active <- n
	return nil
}

// startReconnector launches a goroutine that repeatedly runs checkNodesDown
// and then waits for the configured reconnect interval. The goroutine ends as
// soon as a value is received from c.stop.
func (c *Cluster) startReconnector() {
	go func() {
		for {
			select {
			case <-c.stop:
				log.Debug("Stopping re-connector")
				return
			default:
			}

			log.Debug("Run checkNodesDown")
			if c.checkNodesDown() {
				log.Debug("Stopping re-connector")
				return
			}

			conf := c.conf
			if conf == nil {
				conf = &config.Get().DB
			}
			// Wait for the next round, but stay responsive to a stop
			// request so Close does not block for the whole interval.
			timer := time.NewTimer(time.Duration(conf.ReconnectInterval) * time.Second)
			select {
			case <-c.stop:
				timer.Stop()
				log.Debug("Stopping re-connector")
				return
			case <-timer.C:
			}
		}
	}()
}

// checkNodesDown waits until at least one node is queued as down and then
// tries to reconnect every node that was queued at that moment. It returns
// true if a stop request was received while waiting and false otherwise.
func (c *Cluster) checkNodesDown() bool {
	var n *node
	select {
	case <-c.stop:
		return true
	case n = <-c.down: // blocks until at least one node is down
	}
	// Capture the number of further down nodes before re-adding, because a
	// failed addNode puts the node back on c.down.
	l := len(c.down)
	// Errors are dropped on purpose: addNode itself re-queues a node that
	// could not be connected.
	_ = c.addNode(n)
	for i := 0; i < l; i++ { // check the remaining nodes
		n = <-c.down
		_ = c.addNode(n)
	}
	return false
}

// Close closes the cluster: it signals the re-connector to stop and closes
// the handles of all nodes currently queued as active or down. The stop
// signal is a send on an unbuffered channel, so Close waits until the
// re-connector goroutine has received it.
func (c *Cluster) Close() {
	c.stop <- struct{}{}
	for {
		select {
		case active := <-c.active:
			active.close()
		case inactive := <-c.down:
			inactive.close()
		default:
			return
		}
	}
}

// connectDSN opens a mysql connection for the passed DSN and applies the
// connection pool limits (4 minute connection lifetime, 10 open and 10 idle
// connections).
func connectDSN(dsn string) (*sqlx.DB, error) {
	db, err := sqlx.Connect("mysql", dsn)
	if err != nil {
		return nil, errors.WithStack(err)
	}
	db.SetConnMaxLifetime(time.Minute * 4)
	db.SetMaxOpenConns(10)
	db.SetMaxIdleConns(10)
	return db, nil
}

// Transact does a database transaction for the passed function. If the
// selected node turns out to be down, the node is marked inactive and the
// transaction is retried on the next node; an error is returned when no node
// is available.
func (c *Cluster) Transact(rlog log.Ext1FieldLogger, fn func(*sqlx.Tx) error) error {
	for {
		n := c.next(rlog)
		if n == nil {
			return errors.New("no db node available")
		}
		closed, err := n.transact(rlog, fn)
		if !closed {
			return err
		}
		rlog.Errorf("%s", errorfmt.Full(err))
		n.active = false
	}
}

// transact runs fn in a transaction on this node. The returned bool is true
// when the error indicates that the node itself is down (closed database,
// failed dial or a bad idle connection); the error is returned in either case.
func (n *node) transact(rlog log.Ext1FieldLogger, fn func(*sqlx.Tx) error) (bool, error) {
	err := n.trans(rlog, fn)
	if err != nil {
		e := errorfmt.Error(err)
		switch {
		case e == "sql: database is closed",
			strings.HasPrefix(e, "dial tcp"),
			strings.HasSuffix(e, "closing bad idle connection: EOF"):
			rlog.WithField("host", n.host).Error("Node is down")
			return true, err
		}
	}
	return false, err
}

// trans begins a transaction, runs fn with it and commits. If fn returns an
// error the transaction is rolled back and fn's error is returned; a failing
// rollback is only logged.
func (n *node) trans(rlog log.Ext1FieldLogger, fn func(*sqlx.Tx) error) error {
	tx, err := n.db.Beginx()
	if err != nil {
		return errors.WithStack(err)
	}
	if err = fn(tx); err != nil {
		if e := tx.Rollback(); e != nil {
			rlog.Errorf("%s", errorfmt.Full(e))
		}
		return err
	}
	return errors.WithStack(tx.Commit())
}

// next returns an active node, or nil if no node is queued as active. A
// returned node is put back on the active queue before it is handed out.
// Nodes found to be marked inactive are handed to addNode in the background
// and the search continues with the next queued node.
func (c *Cluster) next(rlog log.Ext1FieldLogger) *node {
	for {
		rlog.Trace("Selecting a node")
		select {
		case n := <-c.active:
			if n.active {
				c.active <- n
				rlog.WithField("host", n.host).Trace("Selected active node")
				return n
			}
			rlog.WithField("host", n.host).Trace("Found inactive node")
			go c.addNode(n) // try to add node again, if it does not work, will add to down nodes
		default:
			rlog.Debug("No active nodes")
			return nil
		}
	}
}

// RunWithinTransaction runs the passed function using the passed transaction; if nil is passed as tx a new transaction
// is created. This is basically a wrapper function, that works with a possible nil-tx
func (c *Cluster) RunWithinTransaction(rlog log.Ext1FieldLogger, tx *sqlx.Tx, fn func(*sqlx.Tx) error) error {
	if tx == nil {
		return c.Transact(rlog, fn)
	}
	return fn(tx)
}