package mytokenrepohelper

import (
	"database/sql"

	"github.com/jmoiron/sqlx"
	"github.com/oidc-mytoken/api/v0"
	"github.com/pkg/errors"
	log "github.com/sirupsen/logrus"

	"github.com/oidc-mytoken/server/internal/db"
	"github.com/oidc-mytoken/server/internal/db/dbrepo/encryptionkeyrepo"
	"github.com/oidc-mytoken/server/internal/utils/hashUtils"
	"github.com/oidc-mytoken/server/shared/mytoken/pkg/mtid"
)

// ParseError classifies the error returned by a single-row query:
//   - nil: returns (true, nil), i.e. a row was found
//   - an error matching sql.ErrNoRows: returns (false, nil), i.e. no row was found, which is not a failure
//   - any other error: returns (false, err)
func ParseError(err error) (bool, error) {
	switch {
	case err == nil:
		return true, nil
	case errors.Is(err, sql.ErrNoRows):
		return false, nil
	default:
		return false, err
	}
}

// recursiveRevokeMT revokes the passed mytoken as well as all children by calling the MTokens_RevokeRec procedure.
func recursiveRevokeMT(rlog log.Ext1FieldLogger, tx *sqlx.Tx, id mtid.MTID) error {
	return db.RunWithinTransaction(
		rlog, tx, func(tx *sqlx.Tx) error {
			_, err := tx.Exec(`CALL MTokens_RevokeRec(?)`, id)
			return errors.WithStack(err)
		},
	)
}

// CheckTokenRevoked checks if a mytoken has been revoked, i.e. whether no valid entry exists for id with the given
// seqno. If rot is non-nil and rot.Lifetime > 0, the rotation lifetime is taken into account as well.
//
// If rot.AutoRevoke is enabled, the token is considered revoked, and its id still exists (so it is invalid because
// of the seqno), the mytoken and all its children are revoked as a side effect.
//
// revoked is true when the token is not valid. On a failed check, revoked is true together with the error.
func CheckTokenRevoked(rlog log.Ext1FieldLogger, tx *sqlx.Tx, id mtid.MTID, seqno uint64, rot *api.Rotation) (
	revoked bool, err error,
) {
	err = db.RunWithinTransaction(
		rlog, tx, func(tx *sqlx.Tx) error {
			var checkErr error
			if rot != nil && rot.Lifetime > 0 {
				revoked, checkErr = checkRotatingTokenRevoked(rlog, tx, id, seqno, rot.Lifetime)
			} else {
				revoked, checkErr = checkTokenRevoked(rlog, tx, id, seqno)
			}
			if checkErr != nil || !revoked || rot == nil || !rot.AutoRevoke {
				// Nothing more to do: the check failed, the token is valid, or auto-revocation does not apply.
				return checkErr
			}
			// At this point we know that the token is not valid. If its id still exists, it is invalid because
			// of the seqno, so the token and all its children get revoked.
			idFound, idErr := checkTokenID(rlog, tx, id)
			if idErr != nil || !idFound {
				return idErr
			}
			return RevokeMT(rlog, tx, id, true)
		},
	)
	return revoked, err
}

// queryMTokenCount runs query, a stored-procedure CALL expected to yield a single integer, with args through
// db.RunWithinTransaction and returns the scanned value. The count is meaningless if err is non-nil.
func queryMTokenCount(rlog log.Ext1FieldLogger, tx *sqlx.Tx, query string, args ...interface{}) (int, error) {
	var count int
	err := db.RunWithinTransaction(
		rlog, tx, func(tx *sqlx.Tx) error {
			return errors.WithStack(tx.Get(&count, query, args...))
		},
	)
	return count, err
}

// checkTokenRevoked reports whether the mytoken id with sequence number seqno is revoked, i.e. MTokens_Check did not
// find it. On error, the token is reported as revoked.
func checkTokenRevoked(rlog log.Ext1FieldLogger, tx *sqlx.Tx, id mtid.MTID, seqno uint64) (bool, error) {
	count, err := queryMTokenCount(rlog, tx, `CALL MTokens_Check(?,?)`, id, seqno)
	if err != nil {
		return true, err
	}
	// The token is only valid if it was found as a mytoken.
	return count <= 0, nil
}

// checkTokenID reports whether a mytoken with the given id exists at all (regardless of its seqno), according to
// MTokens_CheckID. On error, it reports false.
func checkTokenID(rlog log.Ext1FieldLogger, tx *sqlx.Tx, id mtid.MTID) (bool, error) {
	count, err := queryMTokenCount(rlog, tx, `CALL MTokens_CheckID(?)`, id)
	if err != nil {
		return false, err
	}
	return count > 0, nil
}

// checkRotatingTokenRevoked is like checkTokenRevoked but additionally passes the rotation lifetime to the
// MTokens_CheckRotating procedure. On error, the token is reported as revoked.
func checkRotatingTokenRevoked(rlog log.Ext1FieldLogger, tx *sqlx.Tx, id mtid.MTID, seqno, rotationLifetime uint64) (
	bool, error,
) {
	count, err := queryMTokenCount(rlog, tx, `CALL MTokens_CheckRotating(?,?,?)`, id, seqno, rotationLifetime)
	if err != nil {
		return true, err
	}
	// The token is only valid if it was found as a mytoken.
	return count <= 0, nil
}

// UpdateSeqNo updates the sequence number of a mytoken, i.e. it rotates the mytoken.
// It does not touch the encryption key; don't forget to update the encryption key as well.
func UpdateSeqNo(rlog log.Ext1FieldLogger, tx *sqlx.Tx, id mtid.MTID, seqno uint64) error {
	return db.RunWithinTransaction(
		rlog, tx, func(tx *sqlx.Tx) error {
			_, err := tx.Exec(`CALL MTokens_UpdateSeqNo(?,?)`, id, seqno)
			return errors.WithStack(err)
		},
	)
}

// revokeMT revokes the passed mytoken but no children: it deletes the token's encryption key and then the token
// itself (MTokens_Delete).
func revokeMT(rlog log.Ext1FieldLogger, tx *sqlx.Tx, id mtid.MTID) error {
	return db.RunWithinTransaction(
		rlog, tx, func(tx *sqlx.Tx) error {
			if err := encryptionkeyrepo.DeleteEncryptionKey(rlog, tx, id); err != nil {
				return err
			}
			_, err := tx.Exec(`CALL MTokens_Delete(?)`, id)
			return errors.WithStack(err)
		},
	)
}

// RevokeMT revokes the passed mytoken and, if recursive is true, also all its children.
func RevokeMT(rlog log.Ext1FieldLogger, tx *sqlx.Tx, id mtid.MTID, recursive bool) error {
	if recursive {
		return recursiveRevokeMT(rlog, tx, id)
	}
	return revokeMT(rlog, tx, id)
}

// queryTokenUsages runs query, a stored-procedure CALL taking (restrictionHash, myID) and yielding a single usage
// count. If no usage entry exists, the token was not used with this restriction before; this is reported as
// (nil, nil).
func queryTokenUsages(
	rlog log.Ext1FieldLogger, tx *sqlx.Tx, query string, myID mtid.MTID, restrictionHash string,
) (*int64, error) {
	var usageCount int64
	err := db.RunWithinTransaction(
		rlog, tx, func(tx *sqlx.Tx) error {
			return errors.WithStack(tx.Get(&usageCount, query, restrictionHash, myID))
		},
	)
	if found, parseErr := ParseError(err); !found {
		return nil, parseErr
	}
	return &usageCount, nil
}

// GetTokenUsagesAT returns how often a Mytoken was used with a specific restriction to obtain an access token.
// usages is nil (with a nil error) if the token was not used with this restriction before.
func GetTokenUsagesAT(rlog log.Ext1FieldLogger, tx *sqlx.Tx, myID mtid.MTID, restrictionHash string) (
	usages *int64, err error,
) {
	return queryTokenUsages(rlog, tx, `CALL TokenUsages_GetAT(?,?)`, myID, restrictionHash)
}

// GetTokenUsagesOther returns how often a Mytoken was used with a specific restriction to do something else than
// obtaining an access token. usages is nil (with a nil error) if the token was not used with this restriction before.
func GetTokenUsagesOther(rlog log.Ext1FieldLogger, tx *sqlx.Tx, myID mtid.MTID, restrictionHash string) (
	usages *int64, err error,
) {
	return queryTokenUsages(rlog, tx, `CALL TokenUsages_GetOther(?,?)`, myID, restrictionHash)
}

// IncreaseTokenUsageAT increases the usage count for obtaining ATs with a Mytoken and the given restriction.
// The restriction is stored together with its SHA512 hash, which is what the usage getters look up.
func IncreaseTokenUsageAT(rlog log.Ext1FieldLogger, tx *sqlx.Tx, myID mtid.MTID, jsonRestriction []byte) error {
	return db.RunWithinTransaction(
		rlog, tx, func(tx *sqlx.Tx) error {
			_, err := tx.Exec(
				`CALL TokenUsages_IncrAT(?,?,?)`, myID, jsonRestriction, hashUtils.SHA512Str(jsonRestriction),
			)
			return errors.WithStack(err)
		},
	)
}

// IncreaseTokenUsageOther increases the usage count for other usages with a Mytoken and the given restriction.
// The restriction is stored together with its SHA512 hash.
func IncreaseTokenUsageOther(rlog log.Ext1FieldLogger, tx *sqlx.Tx, myID mtid.MTID, jsonRestriction []byte) error {
	return db.RunWithinTransaction(
		rlog, tx, func(tx *sqlx.Tx) error {
			_, err := tx.Exec(
				`CALL TokenUsages_IncrOther(?,?,?)`, myID, jsonRestriction, hashUtils.SHA512Str(jsonRestriction),
			)
			return errors.WithStack(err)
		},
	)
}

// GetMTName returns the name of the mytoken (possibly NULL, hence db.NullString).
func GetMTName(rlog log.Ext1FieldLogger, tx *sqlx.Tx, id mtid.MTID) (name db.NullString, err error) {
	err = db.RunWithinTransaction(
		rlog, tx, func(tx *sqlx.Tx) error {
			return errors.WithStack(tx.Get(&name, `CALL MTokens_GetName(?)`, id))
		},
	)
	return name, err
}