-
Gabriel Zachmann authoredGabriel Zachmann authored
helpers.go 6.80 KiB
package supertokenrepohelper
import (
"database/sql"
"errors"
"github.com/jmoiron/sqlx"
uuid "github.com/satori/go.uuid"
"github.com/zachmann/mytoken/internal/db"
)
// UpdateRefreshToken replaces oldRT with newRT in every SuperTokens row that
// currently holds oldRT, so all super tokens sharing that refresh token are
// updated at once. The statement is executed through db.RunWithinTransaction
// with the passed tx.
func UpdateRefreshToken(tx *sqlx.Tx, oldRT, newRT string) error {
	update := func(t *sqlx.Tx) error {
		_, execErr := t.Exec(`UPDATE SuperTokens SET refresh_token=? WHERE refresh_token=?`, newRT, oldRT)
		return execErr
	}
	return db.RunWithinTransaction(tx, update)
}
// StoreShortSuperToken inserts shortToken into the ShortSuperTokens table,
// linked to the super token with id stid. The insert is executed through
// db.RunWithinTransaction with the passed tx.
func StoreShortSuperToken(tx *sqlx.Tx, shortToken string, stid uuid.UUID) error {
	insert := func(t *sqlx.Tx) error {
		_, execErr := t.Exec(`INSERT INTO ShortSuperTokens (short_token, ST_id) VALUES(?,?)`, shortToken, stid)
		return execErr
	}
	return db.RunWithinTransaction(tx, insert)
}
// GetRefreshToken looks up the refresh token stored for the super token with
// id stid. The boolean result reports whether a matching row was found; a
// missing row is not treated as an error.
func GetRefreshToken(stid uuid.UUID) (string, bool, error) {
	const query = `SELECT refresh_token FROM SuperTokens WHERE id=?`
	var refreshToken string
	// Run the query in its own statement: Go leaves the evaluation order of a
	// plain variable vs. a call in the same argument list unspecified.
	getErr := db.DB().Get(&refreshToken, query, stid)
	return parseStringResult(refreshToken, getErr)
}
// GetRefreshTokenByTokenString looks up the refresh token stored for the
// super token whose (JWT) token string equals token. The boolean result
// reports whether a matching row was found; a missing row is not an error.
func GetRefreshTokenByTokenString(token string) (string, bool, error) {
	const query = `SELECT refresh_token FROM SuperTokens WHERE token=?`
	var refreshToken string
	getErr := db.DB().Get(&refreshToken, query, token)
	return parseStringResult(refreshToken, getErr)
}
// parseStringResult converts the outcome of a single-string query into a
// (value, found, error) triple:
//   - err == nil: (res, true, nil)
//   - err is or wraps sql.ErrNoRows: ("", false, nil); absence is not an error
//   - any other err: ("", false, err)
func parseStringResult(res string, err error) (string, bool, error) {
	switch {
	case err == nil:
		return res, true, nil
	case errors.Is(err, sql.ErrNoRows):
		return "", false, nil
	default:
		return "", false, err
	}
}
// GetSTParentID returns the parent_id of the super token with id stid.
// The boolean reports whether that super token exists; a missing row is not
// an error. If the row exists but parent_id is NULL, the returned id is the
// empty string and the boolean is still true.
func GetSTParentID(stid uuid.UUID) (string, bool, error) {
	var parentID sql.NullString
	getErr := db.DB().Get(&parentID, `SELECT parent_id FROM SuperTokens WHERE id=?`, stid)
	// NullString.String is "" for a NULL column, which is what callers get for
	// a root token.
	return parseStringResult(parentID.String, getErr)
}
// GetSTRootID returns the root_id of the super token with id stid.
// The boolean reports whether that super token exists; a missing row is not
// an error. If the row exists but root_id is NULL, the returned id is the
// empty string and the boolean is still true.
func GetSTRootID(stid uuid.UUID) (string, bool, error) {
	var rootID sql.NullString
	getErr := db.DB().Get(&rootID, `SELECT root_id FROM SuperTokens WHERE id=?`, stid)
	// NullString.String is "" when the column is NULL.
	return parseStringResult(rootID.String, getErr)
}
// RecursiveRevokeSTByTokenString revokes the super token whose token string
// equals token together with all of its descendants. A recursive CTE starts
// at the matching row and follows parent_id links downwards; every id it
// collects is then deleted from SuperTokens. The statement is executed
// through db.RunWithinTransaction with the passed tx.
//
// NOTE(review): MySQL normally rejects a DELETE whose subquery reads the same
// table (error 1093) unless the subquery is materialized; confirm this
// statement works on the deployed server version.
func RecursiveRevokeSTByTokenString(tx *sqlx.Tx, token string) error {
	const deleteTreeQuery = `
DELETE FROM SuperTokens WHERE id=ANY(
WITH Recursive childs
AS
(
SELECT id, parent_id FROM SuperTokens WHERE token=?
UNION ALL
SELECT st.id, st.parent_id FROM SuperTokens st INNER JOIN childs c WHERE st.parent_id=c.id
)
SELECT id
FROM childs
)`
	deleteTree := func(t *sqlx.Tx) error {
		_, execErr := t.Exec(deleteTreeQuery, token)
		return execErr
	}
	return db.RunWithinTransaction(tx, deleteTree)
}
// CheckTokenRevoked takes a short super token or a normal super token and
// checks whether it was revoked. It returns the (full) super token string,
// whether the token counts as revoked, and any database error.
//
// Lookup order:
//  1. If token is present in SuperTokens, it is returned unchanged with
//     revoked=false.
//  2. Otherwise it is looked up as a short token in ShortSuperTokensV. If
//     found, the linked super token string is returned with revoked=false.
//     This function can therefore also exchange a short super token for a
//     normal one.
//  3. If it is found in neither place, it is reported as revoked
//     (revoked=true, nil error) and returned unchanged.
//
// On a database error the passed token is returned with revoked=true, so
// callers that only inspect the boolean fail closed.
func CheckTokenRevoked(token string) (string, bool, error) {
	var count int
	if err := db.DB().Get(&count, `SELECT COUNT(1) FROM SuperTokens WHERE token=?`, token); err != nil {
		return token, true, err
	}
	if count > 0 { // token was found as SuperToken
		return token, false, nil
	}
	var superToken string
	if err := db.DB().Get(&superToken, `SELECT token FROM ShortSuperTokensV WHERE short_token=?`, token); err != nil {
		if errors.Is(err, sql.ErrNoRows) {
			// Neither a super token nor a known short token: it was revoked
			// (or never existed). Reporting "not revoked" here would let a
			// deleted token pass the check.
			return token, true, nil
		}
		return token, true, err
	}
	return superToken, false, nil
}
// RevokeSTByTokenString revokes a single super token by deleting the
// SuperTokens row whose token string equals token. The statement itself only
// targets that row. Whether the schema cascades to child rows is not visible
// here. It is executed through db.RunWithinTransaction with the passed tx.
func RevokeSTByTokenString(tx *sqlx.Tx, token string) error {
	deleteOne := func(t *sqlx.Tx) error {
		_, execErr := t.Exec(`DELETE FROM SuperTokens WHERE token=?`, token)
		return execErr
	}
	return db.RunWithinTransaction(tx, deleteOne)
}
// RevokeSTByToken revokes the super token identified by token. With
// recursive set, RecursiveRevokeSTByTokenString also deletes all descendant
// super tokens. Otherwise RevokeSTByTokenString deletes only the matching row.
func RevokeSTByToken(tx *sqlx.Tx, token string, recursive bool) error {
	if !recursive {
		return RevokeSTByTokenString(tx, token)
	}
	return RecursiveRevokeSTByTokenString(tx, token)
}
// CountRTOccurrences returns how many SuperTokens rows store the refresh
// token rt. The query runs through db.RunWithinTransaction with the passed
// tx, and the error returned is whatever RunWithinTransaction reports.
func CountRTOccurrences(tx *sqlx.Tx, rt string) (count int, err error) {
	countRows := func(t *sqlx.Tx) error {
		return t.Get(&count, `SELECT COUNT(1) FROM SuperTokens WHERE refresh_token=?`, rt)
	}
	err = db.RunWithinTransaction(tx, countRows)
	return count, err
}
// GetTokenUsagesAT returns how often the super token stid was used, under the
// restriction identified by restrictionHash, to obtain an access token
// (column usages_AT). A nil count with a nil error means no TokenUsages row
// exists yet, i.e. the token was not used with this restriction before.
func GetTokenUsagesAT(tx *sqlx.Tx, stid uuid.UUID, restrictionHash string) (usages *int64, err error) {
	var usageCount int64
	err = db.RunWithinTransaction(tx, func(t *sqlx.Tx) error {
		return t.Get(&usageCount, `SELECT usages_AT FROM TokenUsages WHERE restriction_hash=? AND ST_id=?`, restrictionHash, stid)
	})
	switch {
	case err == nil:
		return &usageCount, nil
	case errors.Is(err, sql.ErrNoRows):
		// No usage entry -> was not used before -> usages=nil
		return nil, nil
	default:
		return nil, err
	}
}
// GetTokenUsagesOther returns how often the super token stid was used, under
// the restriction identified by restrictionHash, for something other than
// obtaining an access token (column usages_other). A nil count with a nil
// error means no TokenUsages row exists yet for this pair.
func GetTokenUsagesOther(tx *sqlx.Tx, stid uuid.UUID, restrictionHash string) (usages *int64, err error) {
	var usageCount int64
	err = db.RunWithinTransaction(tx, func(t *sqlx.Tx) error {
		return t.Get(&usageCount, `SELECT usages_other FROM TokenUsages WHERE restriction_hash=? AND ST_id=?`, restrictionHash, stid)
	})
	switch {
	case err == nil:
		return &usageCount, nil
	case errors.Is(err, sql.ErrNoRows):
		// No usage entry -> was not used before -> usages=nil
		return nil, nil
	default:
		return nil, err
	}
}
// IncreaseTokenUsageAT records one more access-token usage of super token
// stid under the JSON-encoded restriction jsonRestriction. It inserts a row
// with usages_AT=1. On a duplicate key it increments usages_AT instead.
// NOTE(review): this relies on a unique key covering the identifying columns
// (presumably ST_id plus the restriction/hash); confirm in the schema.
func IncreaseTokenUsageAT(tx *sqlx.Tx, stid uuid.UUID, jsonRestriction []byte) error {
	upsert := func(t *sqlx.Tx) error {
		_, execErr := t.Exec(`INSERT INTO TokenUsages (ST_id, restriction, usages_AT) VALUES (?, ?, 1) ON DUPLICATE KEY UPDATE usages_AT = usages_AT + 1`, stid, jsonRestriction)
		return execErr
	}
	return db.RunWithinTransaction(tx, upsert)
}
// IncreaseTokenUsageOther records one more non-access-token usage of super
// token stid under the JSON-encoded restriction jsonRestriction. It inserts a
// row with usages_other=1. On a duplicate key it increments usages_other.
// NOTE(review): this relies on a unique key covering the identifying columns
// (presumably ST_id plus the restriction/hash); confirm in the schema.
func IncreaseTokenUsageOther(tx *sqlx.Tx, stid uuid.UUID, jsonRestriction []byte) error {
	upsert := func(t *sqlx.Tx) error {
		_, execErr := t.Exec(`INSERT INTO TokenUsages (ST_id, restriction, usages_other) VALUES (?, ?, 1) ON DUPLICATE KEY UPDATE usages_other = usages_other + 1`, stid, jsonRestriction)
		return execErr
	}
	return db.RunWithinTransaction(tx, upsert)
}