Newer
Older
import (
"database/sql"
"github.com/jmoiron/sqlx"
"github.com/oidc-mytoken/api/v0"
log "github.com/sirupsen/logrus"
"github.com/oidc-mytoken/server/internal/db/dbrepo/encryptionkeyrepo"
"github.com/oidc-mytoken/server/internal/utils/hashUtils"
"github.com/oidc-mytoken/server/shared/mytoken/pkg/mtid"
// ParseError inspects the passed error: a non-nil error is checked against sql.ErrNoRows using errors.Is.
// NOTE(review): how the (bool, error) result is derived from that check is not visible here - confirm before
// relying on it.
func ParseError(err error) (bool, error) {
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
// recursiveRevokeMT revokes the passed mytoken as well as all children
func recursiveRevokeMT(rlog log.Ext1FieldLogger, tx *sqlx.Tx, id mtid.MTID) error {
rlog, tx, func(tx *sqlx.Tx) error {
_, err := tx.Exec(`CALL MTokens_RevokeRec(?)`, id)
return errors.WithStack(err)
},
)
// CheckTokenRevoked checks if a Mytoken has been revoked. If it is a rotating mytoken and auto_revoke is enabled for
// this token, it might get triggered.
func CheckTokenRevoked(rlog log.Ext1FieldLogger, tx *sqlx.Tx, id mtid.MTID, seqno uint64, rot *api.Rotation) (
revoked bool, err error,
) {
err = db.RunWithinTransaction(
rlog, tx, func(tx *sqlx.Tx) error {
if rot == nil {
revoked, err = checkTokenRevoked(rlog, tx, id, seqno)
return err
}
if rot.Lifetime > 0 {
revoked, err = checkRotatingTokenRevoked(rlog, tx, id, seqno, rot.Lifetime)
} else {
revoked, err = checkTokenRevoked(rlog, tx, id, seqno)
}
if err != nil {
return err
}
if !revoked || !rot.AutoRevoke {
return nil
}
// At this point we know, that the token is not valid, we now check if it is not valid because of the seqno.
idFound, err := checkTokenID(rlog, tx, id)
if err != nil || !idFound {
return err
}
return RevokeMT(rlog, tx, id, true)
},
)
return
func checkTokenRevoked(rlog log.Ext1FieldLogger, tx *sqlx.Tx, id mtid.MTID, seqno uint64) (bool, error) {
rlog, tx, func(tx *sqlx.Tx) error {
return errors.WithStack(tx.Get(&count, `CALL MTokens_Check(?,?)`, id, seqno))
},
); err != nil {
return true, err
}
if count > 0 { // token was found as Mytoken
return false, nil
}
return true, nil
}
func checkTokenID(rlog log.Ext1FieldLogger, tx *sqlx.Tx, id mtid.MTID) (bool, error) {
rlog, tx, func(tx *sqlx.Tx) error {
return errors.WithStack(tx.Get(&count, `CALL MTokens_CheckID(?)`, id))
},
); err != nil {
return false, err
}
return count > 0, nil
}
func checkRotatingTokenRevoked(rlog log.Ext1FieldLogger, tx *sqlx.Tx, id mtid.MTID, seqno, rotationLifetime uint64) (
bool, error,
) {
rlog, tx, func(tx *sqlx.Tx) error {
return errors.WithStack(tx.Get(&count, `CALL MTokens_CheckRotating(?,?,?)`, id, seqno, rotationLifetime))
},
); err != nil {
if count > 0 { // token was found as Mytoken
// UpdateSeqNo updates the sequence number of a mytoken, i.e. it rotates the mytoken. Don't forget to update the
// encryption key
func UpdateSeqNo(rlog log.Ext1FieldLogger, tx *sqlx.Tx, id mtid.MTID, seqno uint64) error {
rlog, tx, func(tx *sqlx.Tx) error {
_, err := tx.Exec(`CALL MTokens_UpdateSeqNo(?,?)`, id, seqno)
return errors.WithStack(err)
},
)
// revokeMT revokes the passed mytoken but no children
func revokeMT(rlog log.Ext1FieldLogger, tx *sqlx.Tx, id mtid.MTID) error {
rlog, tx, func(tx *sqlx.Tx) error {
if err := encryptionkeyrepo.DeleteEncryptionKey(rlog, tx, id); err != nil {
return err
}
_, err := tx.Exec(`CALL MTokens_Delete(?)`, id)
return errors.WithStack(err)
},
)
// RevokeMT revokes the passed mytoken and depending on the recursive parameter also its children
func RevokeMT(rlog log.Ext1FieldLogger, tx *sqlx.Tx, id mtid.MTID, recursive bool) error {
return recursiveRevokeMT(rlog, tx, id)
return revokeMT(rlog, tx, id)
// GetTokenUsagesAT returns how often a Mytoken was used with a specific restriction to obtain an access token
func GetTokenUsagesAT(rlog log.Ext1FieldLogger, tx *sqlx.Tx, myID mtid.MTID, restrictionHash string) (
usages *int64, err error,
) {
rlog, tx, func(tx *sqlx.Tx) error {
return errors.WithStack(tx.Get(&usageCount, `CALL TokenUsages_GetAT(?,?)`, restrictionHash, myID))
},
); err != nil {
if errors.Is(err, sql.ErrNoRows) {
// No usage entry -> was not used before -> usages=nil
err = nil // This is fine
return
}
return
}
usages = &usageCount
return
}
// GetTokenUsagesOther returns how often a Mytoken was used with a specific restriction to do something else than
// obtaining an access token
func GetTokenUsagesOther(rlog log.Ext1FieldLogger, tx *sqlx.Tx, myID mtid.MTID, restrictionHash string) (
usages *int64, err error,
) {
rlog, tx, func(tx *sqlx.Tx) error {
return errors.WithStack(tx.Get(&usageCount, `CALL TokenUsages_GetOther(?,?)`, restrictionHash, myID))
},
); err != nil {
if errors.Is(err, sql.ErrNoRows) {
// No usage entry -> was not used before -> usages=nil
err = nil // This is fine
return
}
return
}
usages = &usageCount
return
}
// IncreaseTokenUsageAT increases the usage count for obtaining ATs with a Mytoken and the given restriction
func IncreaseTokenUsageAT(rlog log.Ext1FieldLogger, tx *sqlx.Tx, myID mtid.MTID, jsonRestriction []byte) error {
rlog, tx, func(tx *sqlx.Tx) error {
_, err := tx.Exec(
`CALL TokenUsages_IncrAT(?,?,?)`,
myID, jsonRestriction, hashUtils.SHA512Str(jsonRestriction),
)
return errors.WithStack(err)
},
)
// IncreaseTokenUsageOther increases the usage count for other usages with a Mytoken and the given restriction
func IncreaseTokenUsageOther(rlog log.Ext1FieldLogger, tx *sqlx.Tx, myID mtid.MTID, jsonRestriction []byte) error {
rlog, tx, func(tx *sqlx.Tx) error {
_, err := tx.Exec(
`CALL TokenUsages_IncrOther(?,?,?)`,
myID, jsonRestriction, hashUtils.SHA512Str(jsonRestriction),
)
return errors.WithStack(err)
},
)
}
// GetMTName returns the name of the mytoken
func GetMTName(rlog log.Ext1FieldLogger, tx *sqlx.Tx, id mtid.MTID) (name db.NullString, err error) {
rlog, tx, func(tx *sqlx.Tx) error {
return errors.WithStack(tx.Get(&name, `CALL MTokens_GetName(?)`, id))
},
)
return