Skip to content
Snippets Groups Projects
helpers.go 6.80 KiB
package supertokenrepohelper

import (
	"database/sql"
	"errors"

	"github.com/jmoiron/sqlx"
	uuid "github.com/satori/go.uuid"

	"github.com/zachmann/mytoken/internal/db"
)

// UpdateRefreshToken replaces the refresh token oldRT with newRT in every
// SuperTokens row that currently stores oldRT. The statement is executed via
// db.RunWithinTransaction, which is handed tx; the error it reports is returned
// unchanged.
func UpdateRefreshToken(tx *sqlx.Tx, oldRT, newRT string) error {
	const query = `UPDATE SuperTokens SET refresh_token=? WHERE refresh_token=?`
	update := func(t *sqlx.Tx) error {
		// Only the error matters; the sql.Result is discarded.
		_, execErr := t.Exec(query, newRT, oldRT)
		return execErr
	}
	return db.RunWithinTransaction(tx, update)
}

// StoreShortSuperToken inserts a row into ShortSuperTokens that links the
// short token string shortToken to the SuperToken id stid. The insert is
// executed via db.RunWithinTransaction with the passed tx, and its error is
// returned unchanged.
func StoreShortSuperToken(tx *sqlx.Tx, shortToken string, stid uuid.UUID) error {
	const query = `INSERT INTO ShortSuperTokens (short_token, ST_id) VALUES(?,?)`
	insert := func(t *sqlx.Tx) error {
		// Only the error matters; the sql.Result is discarded.
		_, execErr := t.Exec(query, shortToken, stid)
		return execErr
	}
	return db.RunWithinTransaction(tx, insert)
}

// GetRefreshToken looks up the refresh token stored for the SuperToken with
// the id stid. The query outcome is passed through parseStringResult, so the
// results are (refresh token, true, nil) when a row exists, ("", false, nil)
// when no row matches, and ("", false, err) for any other database error.
func GetRefreshToken(stid uuid.UUID) (string, bool, error) {
	const query = `SELECT refresh_token FROM SuperTokens WHERE id=?`
	var refreshToken string
	getErr := db.DB().Get(&refreshToken, query, stid)
	return parseStringResult(refreshToken, getErr)
}

// GetRefreshTokenByTokenString looks up the refresh token stored for the
// SuperToken whose token column equals the passed token string. The query
// outcome is passed through parseStringResult, so the results are
// (refresh token, true, nil) when a row exists, ("", false, nil) when no row
// matches, and ("", false, err) for any other database error.
func GetRefreshTokenByTokenString(token string) (string, bool, error) {
	const query = `SELECT refresh_token FROM SuperTokens WHERE token=?`
	var refreshToken string
	getErr := db.DB().Get(&refreshToken, query, token)
	return parseStringResult(refreshToken, getErr)
}

// parseStringResult turns the outcome of a single-string lookup into a
// (value, found, error) triple:
//   - err == nil:                  (res, true, nil)
//   - err matches sql.ErrNoRows:   ("", false, nil) - "not found" is not an error
//   - any other err:               ("", false, err)
func parseStringResult(res string, err error) (string, bool, error) {
	switch {
	case err == nil:
		return res, true, nil
	case errors.Is(err, sql.ErrNoRows):
		return "", false, nil
	default:
		return "", false, err
	}
}

// GetSTParentID looks up the parent_id column of the SuperToken with the id
// stid.
//
// Results:
//   - row found:        (parent id, true, nil); the column is scanned into a
//     sql.NullString, so a NULL parent_id yields "" with found == true
//   - no matching row:  ("", false, nil)
//   - other db error:   ("", false, err)
func GetSTParentID(stid uuid.UUID) (string, bool, error) {
	const query = `SELECT parent_id FROM SuperTokens WHERE id=?`
	var parent sql.NullString
	err := db.DB().Get(&parent, query, stid)
	switch {
	case err == nil:
		return parent.String, true, nil
	case errors.Is(err, sql.ErrNoRows):
		return "", false, nil
	default:
		return "", false, err
	}
}

// GetSTRootID looks up the root_id column of the SuperToken with the id stid.
//
// Results:
//   - row found:        (root id, true, nil); the column is scanned into a
//     sql.NullString, so a NULL root_id yields "" with found == true
//   - no matching row:  ("", false, nil)
//   - other db error:   ("", false, err)
func GetSTRootID(stid uuid.UUID) (string, bool, error) {
	const query = `SELECT root_id FROM SuperTokens WHERE id=?`
	var root sql.NullString
	err := db.DB().Get(&root, query, stid)
	switch {
	case err == nil:
		return root.String, true, nil
	case errors.Is(err, sql.ErrNoRows):
		return "", false, nil
	default:
		return "", false, err
	}
}

// RecursiveRevokeSTByTokenString deletes the SuperToken whose token column
// equals token together with all of its descendants.
//
// The recursive CTE "childs" is seeded with the row matching token and then
// repeatedly extended with every SuperTokens row whose parent_id equals an id
// already collected; the DELETE removes all rows whose id is in that set. The
// statement is executed via db.RunWithinTransaction with the passed tx, and
// its error is returned unchanged.
func RecursiveRevokeSTByTokenString(tx *sqlx.Tx, token string) error {
	revoke := func(t *sqlx.Tx) error {
		_, execErr := t.Exec(`
			DELETE FROM SuperTokens WHERE id=ANY(
			WITH Recursive childs
			AS
			(
				SELECT id, parent_id FROM SuperTokens WHERE token=?
				UNION ALL
				SELECT st.id, st.parent_id FROM SuperTokens st INNER JOIN childs c WHERE st.parent_id=c.id
			)
			SELECT id
			FROM   childs
			)`, token)
		return execErr
	}
	return db.RunWithinTransaction(tx, revoke)
}

// CheckTokenRevoked takes a short super token or a normal super token and
// checks if it was revoked. If the token is found in the db, the super token
// string will be returned. Therefore, this function can also be used to
// exchange a short super token into a normal one.
//
// Results (super token, revoked, error):
//   - token is stored in SuperTokens:            (token, false, nil)
//   - token is a short token in ShortSuperTokensV: (linked super token, false, nil)
//   - token is found in neither place:           ("", true, nil)
//   - a database error occurred:                 (token, true, err)
func CheckTokenRevoked(token string) (string, bool, error) {
	var count int
	if err := db.DB().Get(&count, `SELECT COUNT(1) FROM SuperTokens WHERE token=?`, token); err != nil {
		return token, true, err
	}
	if count > 0 { // token was found as SuperToken
		return token, false, nil
	}
	// Not stored as a SuperToken; try to resolve it as a short super token.
	var superToken string
	if err := db.DB().Get(&superToken, `SELECT token FROM ShortSuperTokensV WHERE short_token=?`, token); err != nil {
		if errors.Is(err, sql.ErrNoRows) {
			// Unknown in both places. Revoking a token deletes its row (see
			// RevokeSTByTokenString), so an unknown token must be reported as
			// revoked instead of silently passing as valid.
			return "", true, nil
		}
		return token, true, err
	}
	return superToken, false, nil
}

// RevokeSTByTokenString deletes the SuperTokens row whose token column equals
// token. This statement only targets that row, not its children. It is
// executed via db.RunWithinTransaction with the passed tx, and its error is
// returned unchanged.
func RevokeSTByTokenString(tx *sqlx.Tx, token string) error {
	const query = `DELETE FROM SuperTokens WHERE token=?`
	revoke := func(t *sqlx.Tx) error {
		// Only the error matters; the sql.Result is discarded.
		_, execErr := t.Exec(query, token)
		return execErr
	}
	return db.RunWithinTransaction(tx, revoke)
}

// RevokeSTByToken revokes the passed super token. With recursive set it
// delegates to RecursiveRevokeSTByTokenString, which also removes the token's
// children; otherwise it delegates to RevokeSTByTokenString, which removes
// only the token itself.
func RevokeSTByToken(tx *sqlx.Tx, token string, recursive bool) error {
	if !recursive {
		return RevokeSTByTokenString(tx, token)
	}
	return RecursiveRevokeSTByTokenString(tx, token)
}

// CountRTOccurrences returns the number of SuperTokens rows whose
// refresh_token column equals rt. The query is executed via
// db.RunWithinTransaction with the passed tx; err is the error reported by
// that call.
func CountRTOccurrences(tx *sqlx.Tx, rt string) (count int, err error) {
	const query = `SELECT COUNT(1) FROM SuperTokens WHERE refresh_token=?`
	err = db.RunWithinTransaction(tx, func(t *sqlx.Tx) error {
		return t.Get(&count, query, rt)
	})
	return count, err
}

// GetTokenUsagesAT reads the usages_AT counter from TokenUsages for the
// SuperToken stid and the restriction identified by restrictionHash, i.e. how
// often the token was used with that restriction to obtain an access token.
//
// Results:
//   - entry found:     (pointer to the stored count, nil)
//   - no entry exists: (nil, nil) - the restriction was not used before
//   - other db error:  (nil, err)
func GetTokenUsagesAT(tx *sqlx.Tx, stid uuid.UUID, restrictionHash string) (usages *int64, err error) {
	const query = `SELECT usages_AT FROM TokenUsages WHERE restriction_hash=? AND ST_id=?`
	var stored int64
	err = db.RunWithinTransaction(tx, func(t *sqlx.Tx) error {
		return t.Get(&stored, query, restrictionHash, stid)
	})
	switch {
	case err == nil:
		return &stored, nil
	case errors.Is(err, sql.ErrNoRows):
		// A missing usage entry is fine: usages stays nil and no error is reported.
		return nil, nil
	default:
		return nil, err
	}
}

// GetTokenUsagesOther reads the usages_other counter from TokenUsages for the
// SuperToken stid and the restriction identified by restrictionHash, i.e. how
// often the token was used with that restriction for something other than
// obtaining an access token.
//
// Results:
//   - entry found:     (pointer to the stored count, nil)
//   - no entry exists: (nil, nil) - the restriction was not used before
//   - other db error:  (nil, err)
func GetTokenUsagesOther(tx *sqlx.Tx, stid uuid.UUID, restrictionHash string) (usages *int64, err error) {
	const query = `SELECT usages_other FROM TokenUsages WHERE restriction_hash=? AND ST_id=?`
	var stored int64
	err = db.RunWithinTransaction(tx, func(t *sqlx.Tx) error {
		return t.Get(&stored, query, restrictionHash, stid)
	})
	switch {
	case err == nil:
		return &stored, nil
	case errors.Is(err, sql.ErrNoRows):
		// A missing usage entry is fine: usages stays nil and no error is reported.
		return nil, nil
	default:
		return nil, err
	}
}

// IncreaseTokenUsageAT records one more access-token usage for the SuperToken
// stid under the restriction jsonRestriction. The statement inserts a
// TokenUsages row with usages_AT = 1, or increments usages_AT by one when the
// insert hits a duplicate key. It is executed via db.RunWithinTransaction with
// the passed tx, and its error is returned unchanged.
func IncreaseTokenUsageAT(tx *sqlx.Tx, stid uuid.UUID, jsonRestriction []byte) error {
	const query = `INSERT INTO TokenUsages (ST_id, restriction, usages_AT) VALUES (?, ?, 1) ON DUPLICATE KEY UPDATE usages_AT = usages_AT + 1`
	increase := func(t *sqlx.Tx) error {
		// Only the error matters; the sql.Result is discarded.
		_, execErr := t.Exec(query, stid, jsonRestriction)
		return execErr
	}
	return db.RunWithinTransaction(tx, increase)
}

// IncreaseTokenUsageOther records one more non-access-token usage for the
// SuperToken stid under the restriction jsonRestriction. The statement inserts
// a TokenUsages row with usages_other = 1, or increments usages_other by one
// when the insert hits a duplicate key. It is executed via
// db.RunWithinTransaction with the passed tx, and its error is returned
// unchanged.
func IncreaseTokenUsageOther(tx *sqlx.Tx, stid uuid.UUID, jsonRestriction []byte) error {
	const query = `INSERT INTO TokenUsages (ST_id, restriction, usages_other) VALUES (?, ?, 1) ON DUPLICATE KEY UPDATE usages_other = usages_other + 1`
	increase := func(t *sqlx.Tx) error {
		// Only the error matters; the sql.Result is discarded.
		_, execErr := t.Exec(query, stid, jsonRestriction)
		return execErr
	}
	return db.RunWithinTransaction(tx, increase)
}