Newer
Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
package supertokenrepohelper
import (
"database/sql"
"errors"
"github.com/jmoiron/sqlx"
uuid "github.com/satori/go.uuid"
"github.com/zachmann/mytoken/internal/db"
)
// UpdateRefreshToken swaps oldRT for newRT in every SuperTokens row that
// currently stores oldRT, so all super tokens sharing that refresh token see
// the new value. The UPDATE is executed through db.RunWithinTransaction with tx.
func UpdateRefreshToken(tx *sqlx.Tx, oldRT, newRT string) error {
	const updateRT = `UPDATE SuperTokens SET refresh_token=? WHERE refresh_token=?`
	apply := func(t *sqlx.Tx) error {
		if _, execErr := t.Exec(updateRT, newRT, oldRT); execErr != nil {
			return execErr
		}
		return nil
	}
	return db.RunWithinTransaction(tx, apply)
}
// StoreShortSuperToken inserts shortToken into ShortSuperTokens, linking it to
// the SuperToken whose id is stid (column ST_id). The INSERT is executed through
// db.RunWithinTransaction with tx.
func StoreShortSuperToken(tx *sqlx.Tx, shortToken string, stid uuid.UUID) error {
	const insertShort = `INSERT INTO ShortSuperTokens (short_token, ST_id) VALUES(?,?)`
	insert := func(t *sqlx.Tx) error {
		if _, execErr := t.Exec(insertShort, shortToken, stid); execErr != nil {
			return execErr
		}
		return nil
	}
	return db.RunWithinTransaction(tx, insert)
}
// GetRefreshToken looks up the refresh token stored for the super token with id
// stid. It returns (rt, true, nil) on success, ("", false, nil) if no such row
// exists, and ("", false, err) on any other database error.
func GetRefreshToken(stid uuid.UUID) (string, bool, error) {
	const query = `SELECT refresh_token FROM SuperTokens WHERE id=?`
	var refreshToken string
	queryErr := db.DB().Get(&refreshToken, query, stid)
	return parseStringResult(refreshToken, queryErr)
}
// GetRefreshTokenByTokenString looks up the refresh token stored for the super
// token whose token column equals token (the super token JWT string). It returns
// (rt, true, nil) on success, ("", false, nil) if no such row exists, and
// ("", false, err) on any other database error.
func GetRefreshTokenByTokenString(token string) (string, bool, error) {
	const query = `SELECT refresh_token FROM SuperTokens WHERE token=?`
	var refreshToken string
	queryErr := db.DB().Get(&refreshToken, query, token)
	return parseStringResult(refreshToken, queryErr)
}
// parseStringResult maps the outcome of a single-value string query onto the
// (value, found, error) triple returned by the getters in this package:
//   - err == nil                       -> (res, true, nil)
//   - errors.Is(err, sql.ErrNoRows)    -> ("", false, nil)
//   - any other error                  -> ("", false, err)
func parseStringResult(res string, err error) (string, bool, error) {
	switch {
	case err == nil:
		return res, true, nil
	case errors.Is(err, sql.ErrNoRows):
		// No matching row is "not found", not a failure.
		return "", false, nil
	default:
		return "", false, err
	}
}
// GetSTParentID reads the parent_id column of the super token with id stid.
// It returns (parentID, true, nil) when the row exists. A NULL parent_id is
// reported as ("", true, nil). It returns ("", false, nil) when no row matches
// stid, and ("", false, err) on any other database error.
func GetSTParentID(stid uuid.UUID) (string, bool, error) {
	var parent sql.NullString
	queryErr := db.DB().Get(&parent, `SELECT parent_id FROM SuperTokens WHERE id=?`, stid)
	switch {
	case queryErr == nil:
		return parent.String, true, nil
	case errors.Is(queryErr, sql.ErrNoRows):
		return "", false, nil
	default:
		return "", false, queryErr
	}
}
// GetSTRootID reads the root_id column of the super token with id stid.
// It returns (rootID, true, nil) when the row exists. A NULL root_id is
// reported as ("", true, nil). It returns ("", false, nil) when no row matches
// stid, and ("", false, err) on any other database error.
func GetSTRootID(stid uuid.UUID) (string, bool, error) {
	var root sql.NullString
	queryErr := db.DB().Get(&root, `SELECT root_id FROM SuperTokens WHERE id=?`, stid)
	switch {
	case queryErr == nil:
		return root.String, true, nil
	case errors.Is(queryErr, sql.ErrNoRows):
		return "", false, nil
	default:
		return "", false, queryErr
	}
}
// RecursiveRevokeSTByTokenString revokes the passed super token as well as all children.
//
// A single DELETE removes the SuperTokens row whose token column equals token.
// It also removes every row reachable from that row through parent_id links
// (children, grandchildren, ...). The recursive CTE "childs" collects the ids.
// Its anchor selects the row matching token, and its recursive step adds every
// row whose parent_id is an id already collected. The statement is executed
// through db.RunWithinTransaction with tx. The Exec result is discarded, so the
// number of deleted rows (possibly zero) is not checked.
//
// NOTE(review): MySQL can reject a DELETE whose subquery reads the same table
// (error 1093), depending on server version and optimizer behavior. Confirm
// this statement works against the target database.
func RecursiveRevokeSTByTokenString(tx *sqlx.Tx, token string) error {
	return db.RunWithinTransaction(tx, func(tx *sqlx.Tx) error {
		// The recursive step's join condition is written in WHERE instead of an ON clause.
		_, err := tx.Exec(`
DELETE FROM SuperTokens WHERE id=ANY(
WITH Recursive childs
AS
(
SELECT id, parent_id FROM SuperTokens WHERE token=?
UNION ALL
SELECT st.id, st.parent_id FROM SuperTokens st INNER JOIN childs c WHERE st.parent_id=c.id
)
SELECT id
FROM childs
)`, token)
		return err
	})
}
// CheckTokenRevoked takes a short super token or a normal super token and checks whether it
// was revoked. Revoking a super token deletes its SuperTokens row (see RevokeSTByTokenString).
// Therefore a token found in neither SuperTokens nor ShortSuperTokensV is reported as revoked.
//
// Return values:
//   - string: the super token string. This is token itself if it is a stored super token, or
//     the linked super token if token is a stored short super token. If the token is revoked
//     or a query fails, token is returned unchanged.
//   - bool: true if the token is revoked (not found) or a query failed.
//   - error: any database error other than sql.ErrNoRows.
//
// Because a short super token is resolved to its super token, this function can also be
// used to exchange a short super token for a normal one.
func CheckTokenRevoked(token string) (string, bool, error) {
	var count int
	if err := db.DB().Get(&count, `SELECT COUNT(1) FROM SuperTokens WHERE token=?`, token); err != nil {
		return token, true, err
	}
	if count > 0 { // token was found as SuperToken
		return token, false, nil
	}
	var superToken string
	if err := db.DB().Get(&superToken, `SELECT token FROM ShortSuperTokensV WHERE short_token=?`, token); err != nil {
		if errors.Is(err, sql.ErrNoRows) {
			// Neither a known super token nor a known short super token: treat it as
			// revoked instead of reporting a valid token with an empty string.
			return token, true, nil
		}
		return token, true, err
	}
	return superToken, false, nil
}
// RevokeSTByTokenString revokes a single super token by deleting the SuperTokens
// row whose token column equals token. Child tokens are not touched. The Exec
// result is discarded, so deleting zero rows is not reported. The DELETE is
// executed through db.RunWithinTransaction with tx.
func RevokeSTByTokenString(tx *sqlx.Tx, token string) error {
	const deleteOne = `DELETE FROM SuperTokens WHERE token=?`
	revoke := func(t *sqlx.Tx) error {
		if _, execErr := t.Exec(deleteOne, token); execErr != nil {
			return execErr
		}
		return nil
	}
	return db.RunWithinTransaction(tx, revoke)
}
// RevokeSTByToken revokes the super token identified by its token string.
// With recursive set, all descendant tokens are revoked as well
// (RecursiveRevokeSTByTokenString). Otherwise only the token itself is revoked
// (RevokeSTByTokenString).
func RevokeSTByToken(tx *sqlx.Tx, token string, recursive bool) error {
	if !recursive {
		return RevokeSTByTokenString(tx, token)
	}
	return RecursiveRevokeSTByTokenString(tx, token)
}
// CountRTOccurrences returns how many SuperTokens rows store rt as their
// refresh_token. The COUNT query runs through db.RunWithinTransaction with tx.
// The returned error is whatever db.RunWithinTransaction reports.
func CountRTOccurrences(tx *sqlx.Tx, rt string) (count int, err error) {
	const countQuery = `SELECT COUNT(1) FROM SuperTokens WHERE refresh_token=?`
	countRows := func(t *sqlx.Tx) error {
		return t.Get(&count, countQuery, rt)
	}
	err = db.RunWithinTransaction(tx, countRows)
	return count, err
}
// GetTokenUsagesAT returns how often the SuperToken stid was used with the
// restriction identified by restrictionHash to obtain an access token. The value
// is read from the usages_AT column of TokenUsages. If no TokenUsages row exists
// yet, the token was never used with that restriction and (nil, nil) is returned.
// Any other database error is returned with nil usages.
func GetTokenUsagesAT(tx *sqlx.Tx, stid uuid.UUID, restrictionHash string) (usages *int64, err error) {
	var n int64
	err = db.RunWithinTransaction(tx, func(t *sqlx.Tx) error {
		return t.Get(&n, `SELECT usages_AT FROM TokenUsages WHERE restriction_hash=? AND ST_id=?`, restrictionHash, stid)
	})
	switch {
	case err == nil:
		return &n, nil
	case errors.Is(err, sql.ErrNoRows):
		// Missing row means "never used", which is not an error.
		return nil, nil
	default:
		return nil, err
	}
}
// GetTokenUsagesOther returns how often the SuperToken stid was used with the
// restriction identified by restrictionHash for something other than obtaining
// an access token. The value is read from the usages_other column of TokenUsages.
// If no TokenUsages row exists yet, the token was never used with that
// restriction and (nil, nil) is returned. Any other database error is returned
// with nil usages.
func GetTokenUsagesOther(tx *sqlx.Tx, stid uuid.UUID, restrictionHash string) (usages *int64, err error) {
	var n int64
	err = db.RunWithinTransaction(tx, func(t *sqlx.Tx) error {
		return t.Get(&n, `SELECT usages_other FROM TokenUsages WHERE restriction_hash=? AND ST_id=?`, restrictionHash, stid)
	})
	switch {
	case err == nil:
		return &n, nil
	case errors.Is(err, sql.ErrNoRows):
		// Missing row means "never used", which is not an error.
		return nil, nil
	default:
		return nil, err
	}
}
// IncreaseTokenUsageAT records one more access-token usage of SuperToken stid
// under the given JSON-encoded restriction. It inserts a TokenUsages row with
// usages_AT=1. If the insert hits a duplicate key, it increments usages_AT on
// the existing row instead. Which unique key triggers the update depends on the
// table schema (presumably ST_id plus a hash of restriction; confirm).
func IncreaseTokenUsageAT(tx *sqlx.Tx, stid uuid.UUID, jsonRestriction []byte) error {
	const upsert = `INSERT INTO TokenUsages (ST_id, restriction, usages_AT) VALUES (?, ?, 1) ON DUPLICATE KEY UPDATE usages_AT = usages_AT + 1`
	bump := func(t *sqlx.Tx) error {
		if _, execErr := t.Exec(upsert, stid, jsonRestriction); execErr != nil {
			return execErr
		}
		return nil
	}
	return db.RunWithinTransaction(tx, bump)
}
// IncreaseTokenUsageOther records one more non-access-token usage of SuperToken
// stid under the given JSON-encoded restriction. It inserts a TokenUsages row
// with usages_other=1. If the insert hits a duplicate key, it increments
// usages_other on the existing row instead. Which unique key triggers the update
// depends on the table schema (presumably ST_id plus a hash of restriction;
// confirm).
func IncreaseTokenUsageOther(tx *sqlx.Tx, stid uuid.UUID, jsonRestriction []byte) error {
	const upsert = `INSERT INTO TokenUsages (ST_id, restriction, usages_other) VALUES (?, ?, 1) ON DUPLICATE KEY UPDATE usages_other = usages_other + 1`
	bump := func(t *sqlx.Tx) error {
		if _, execErr := t.Exec(upsert, stid, jsonRestriction); execErr != nil {
			return execErr
		}
		return nil
	}
	return db.RunWithinTransaction(tx, bump)
}