// restriction.go
// Authored by Gabriel Zachmann.
package restrictions
import (
"database/sql"
"database/sql/driver"
"encoding/json"
"errors"
"math"
"strings"
"time"
"github.com/zachmann/mytoken/internal/db"
"github.com/zachmann/mytoken/internal/utils/hashUtils"
"github.com/zachmann/mytoken/internal/utils/geoip"
uuid "github.com/satori/go.uuid"
"github.com/jinzhu/copier"
log "github.com/sirupsen/logrus"
"github.com/zachmann/mytoken/internal/utils"
)
// Restrictions is a slice of Restriction. The VerifyFor* methods on this type
// treat an empty slice as "unrestricted" and otherwise require at least one
// element to pass verification.
type Restrictions []Restriction
// Restriction describes a token usage restriction.
//
// The struct is JSON-encoded both for persistence and as the input of Hash, so
// field order, types and tags are part of the stored format: encoding/json
// emits struct fields in declaration order, and changing them changes the hash
// of every existing restriction.
type Restriction struct {
	// NotBefore is compared against time.Now().Unix(); the restriction is not
	// valid before this instant. The zero value imposes no lower bound for
	// any current time.
	NotBefore int64 `json:"nbf,omitempty"`
	// ExpiresAt is compared against time.Now().Unix(); 0 means "never expires".
	ExpiresAt int64 `json:"exp,omitempty"`
	// Scope is a space-separated list of scopes; empty means no scope limit.
	Scope string `json:"scope,omitempty"`
	// Audiences lists the allowed audiences; empty means no audience limit.
	Audiences []string `json:"audience,omitempty"`
	// IPs lists the allowed client addresses, matched with utils.IPIsIn;
	// empty means any address is accepted.
	IPs []string `json:"ip,omitempty"`
	// GeoIPWhite lists values compared against geoip.CountryCode(ip); when
	// non-empty the client's value must be contained in it.
	GeoIPWhite []string `json:"geoip_white,omitempty"`
	// GeoIPBlack lists values compared against geoip.CountryCode(ip); when
	// non-empty the client's value must not be contained in it.
	GeoIPBlack []string `json:"geoip_black,omitempty"`
	// UsagesAT limits how often the restriction may be used to obtain an
	// access token; nil means unlimited.
	UsagesAT *int64 `json:"usages_AT,omitempty"`
	// UsagesOther limits how often the restriction may be used for other
	// actions; nil means unlimited.
	UsagesOther *int64 `json:"usages_other,omitempty"`
	//Usages *int64 `json:"usages,omitempty"`
}
// Hash returns the SHA-512 digest (as produced by hashUtils.SHA512) of the
// JSON encoding of the restriction. A marshalling error is returned unchanged
// with a nil digest.
func (r *Restriction) Hash() ([]byte, error) {
	encoded, marshalErr := json.Marshal(r)
	if marshalErr != nil {
		return nil, marshalErr
	}
	return hashUtils.SHA512(encoded)
}
// VerifyTimeBased reports whether the current time (Unix seconds) lies inside
// the restriction's validity window: not earlier than NotBefore and, if
// ExpiresAt is set (non-zero), not later than ExpiresAt.
func (r *Restriction) VerifyTimeBased() bool {
	log.Trace("Verifying time based")
	current := time.Now().Unix()
	if current < r.NotBefore {
		return false
	}
	if r.ExpiresAt == 0 {
		// No expiry configured.
		return true
	}
	return current <= r.ExpiresAt
}
// VerifyIPBased reports whether ip passes both the explicit IP list check and
// the geo-location checks. The geo checks are skipped when the IP list check
// already fails.
func (r *Restriction) VerifyIPBased(ip string) bool {
	if !r.verifyIPs(ip) {
		return false
	}
	return r.verifyGeoIP(ip)
}
// verifyIPs reports whether ip is acceptable with respect to the IPs list.
// An empty list accepts every address; otherwise the decision is delegated to
// utils.IPIsIn.
func (r *Restriction) verifyIPs(ip string) bool {
	log.Trace("Verifying ips")
	if len(r.IPs) == 0 {
		return true
	}
	return utils.IPIsIn(ip, r.IPs)
}
// verifyGeoIP reports whether ip passes the geo-location black list and then
// the geo-location white list. The white list is only consulted when the
// black list check succeeded.
func (r *Restriction) verifyGeoIP(ip string) bool {
	log.Trace("Verifying ip geo location")
	if !r.verifyGeoIPBlack(ip) {
		return false
	}
	return r.verifyGeoIPWhite(ip)
}
// verifyGeoIPWhite reports whether ip is allowed by GeoIPWhite. An empty white
// list allows everything; otherwise the value returned by geoip.CountryCode
// for ip must be contained in the list.
func (r *Restriction) verifyGeoIPWhite(ip string) bool {
	log.Trace("Verifying ip geo location white list")
	if allowed := r.GeoIPWhite; len(allowed) != 0 {
		return utils.StringInSlice(geoip.CountryCode(ip), allowed)
	}
	return true
}
// verifyGeoIPBlack reports whether ip is allowed by GeoIPBlack. An empty black
// list allows everything; otherwise the value returned by geoip.CountryCode
// for ip must not be contained in the list.
func (r *Restriction) verifyGeoIPBlack(ip string) bool {
	log.Trace("Verifying ip geo location black list")
	if denied := r.GeoIPBlack; len(denied) != 0 {
		return !utils.StringInSlice(geoip.CountryCode(ip), denied)
	}
	return true
}
// VerifyATUsageCounts reports whether this restriction may still be used to
// obtain an access token for the token identified by stid.
//
// A nil UsagesAT means unlimited and always passes. Otherwise the number of
// recorded usages is read from TokenUsages (keyed by the restriction hash and
// stid) and must be strictly below the limit. A missing row counts as zero
// usages. Hashing or database failures are logged and reported as "not valid".
func (r *Restriction) VerifyATUsageCounts(stid uuid.UUID) bool {
	log.Trace("Verifying AT usage count")
	limit := r.UsagesAT
	if limit == nil {
		return true
	}
	digest, hashErr := r.Hash()
	if hashErr != nil {
		log.WithError(hashErr).Error()
		return false
	}
	const query = `SELECT usages_AT FROM TokenUsages WHERE restriction_hash=? AND ST_id=?`
	var used int64
	dbErr := db.DB().Get(&used, query, string(digest), stid)
	switch {
	case dbErr == nil:
		log.WithFields(log.Fields{
			"stid":             stid.String(),
			"restriction_hash": string(digest),
			"used":             used,
			"usageLimit":       *limit,
		}).Debug("Found in db.")
		return used < *limit
	case errors.Is(dbErr, sql.ErrNoRows):
		// No usage entry -> was not used before
		log.WithFields(log.Fields{
			"stid":             stid.String(),
			"restriction_hash": string(digest),
		}).Debug("Did not found restriction in database; it was not used before")
		return *limit > 0
	default:
		log.WithError(dbErr).Error()
		return false
	}
}
// VerifyOtherUsageCounts reports whether this restriction may still be used
// for a non-access-token action on the token identified by stid.
//
// A nil UsagesOther means unlimited and always passes. Otherwise the number of
// recorded usages is read from TokenUsages (keyed by the restriction hash and
// stid) and must be strictly below the limit. A missing row counts as zero
// usages. Hashing or database failures are logged and reported as "not valid".
func (r *Restriction) VerifyOtherUsageCounts(stid uuid.UUID) bool {
	limit := r.UsagesOther
	if limit == nil {
		return true
	}
	digest, hashErr := r.Hash()
	if hashErr != nil {
		log.WithError(hashErr).Error()
		return false
	}
	const query = `SELECT usages_other FROM TokenUsages WHERE restriction_hash=? AND ST_id=?`
	var used int64
	dbErr := db.DB().Get(&used, query, string(digest), stid)
	switch {
	case dbErr == nil:
		return used < *limit
	case errors.Is(dbErr, sql.ErrNoRows):
		// No usage entry -> was not used before
		return *limit > 0
	default:
		log.WithError(dbErr).Error()
		return false
	}
}
// verify runs the checks shared by all usage kinds: the time window first and,
// only if that passes, the IP-based checks for ip.
func (r *Restriction) verify(ip string) bool {
	if !r.VerifyTimeBased() {
		return false
	}
	return r.VerifyIPBased(ip)
}
// VerifyAT reports whether the restriction permits obtaining an access token
// from ip for the token stid. The usage counter is only consulted once the
// time and IP checks have passed.
func (r *Restriction) VerifyAT(ip string, stid uuid.UUID) bool {
	if !r.verify(ip) {
		return false
	}
	return r.VerifyATUsageCounts(stid)
}
// VerifyOther reports whether the restriction permits a non-access-token
// action from ip for the token stid. The usage counter is only consulted once
// the time and IP checks have passed.
func (r *Restriction) VerifyOther(ip string, stid uuid.UUID) bool {
	if !r.verify(ip) {
		return false
	}
	return r.VerifyOtherUsageCounts(stid)
}
// UsedAT records one access-token usage of this restriction for the token
// stid. The row is inserted with a count of 1, or the existing counter is
// incremented if the key already exists.
//
// NOTE(review): the insert supplies the JSON in the `restriction` column while
// the usage lookups in this file filter on `restriction_hash` — presumably the
// hash is derived by the database; confirm against the schema.
func (r *Restriction) UsedAT(stid uuid.UUID) error {
	encoded, marshalErr := json.Marshal(r)
	if marshalErr != nil {
		return marshalErr
	}
	const stmt = `INSERT INTO TokenUsages (ST_id, restriction, usages_AT) VALUES (?, ?, 1) ON DUPLICATE KEY UPDATE usages_AT = usages_AT + 1`
	_, execErr := db.DB().Exec(stmt, stid, encoded)
	return execErr
}
// UsedOther records one non-access-token usage of this restriction for the
// token stid. The row is inserted with a count of 1, or the existing counter
// is incremented if the key already exists.
//
// NOTE(review): the insert supplies the JSON in the `restriction` column while
// the usage lookups in this file filter on `restriction_hash` — presumably the
// hash is derived by the database; confirm against the schema.
func (r *Restriction) UsedOther(stid uuid.UUID) error {
	encoded, marshalErr := json.Marshal(r)
	if marshalErr != nil {
		return marshalErr
	}
	const stmt = `INSERT INTO TokenUsages (ST_id, restriction, usages_other) VALUES (?, ?, 1) ON DUPLICATE KEY UPDATE usages_other = usages_other + 1`
	_, execErr := db.DB().Exec(stmt, stid, encoded)
	return execErr
}
// VerifyForAT reports whether an access token may be obtained from ip for the
// token stid: either there are no restrictions at all, or at least one
// restriction is currently valid.
func (r Restrictions) VerifyForAT(ip string, stid uuid.UUID) bool {
	return len(r) == 0 || len(r.GetValidForAT(ip, stid)) > 0
}
// VerifyForOther reports whether a non-access-token action is allowed from ip
// for the token stid: either there are no restrictions at all, or at least one
// restriction is currently valid.
func (r Restrictions) VerifyForOther(ip string, stid uuid.UUID) bool {
	return len(r) == 0 || len(r.GetValidForOther(ip, stid)) > 0
}
// GetValidForAT returns, in their original order, the restrictions that pass
// VerifyAT for ip and stid. The result is nil when none passes.
func (r Restrictions) GetValidForAT(ip string, stid uuid.UUID) (ret Restrictions) {
	for i := range r {
		candidate := r[i]
		if !candidate.VerifyAT(ip, stid) {
			continue
		}
		log.Trace("Found a valid restriction")
		ret = append(ret, candidate)
	}
	return ret
}
// GetValidForOther returns, in their original order, the restrictions that
// pass VerifyOther for ip and stid. The result is nil when none passes.
func (r Restrictions) GetValidForOther(ip string, stid uuid.UUID) (ret Restrictions) {
	for i := range r {
		candidate := r[i]
		if !candidate.VerifyOther(ip, stid) {
			continue
		}
		ret = append(ret, candidate)
	}
	return ret
}
// WithScopes filters the restrictions for the given scopes. With no scopes the
// receiver is returned unchanged. Otherwise a restriction is kept when it has
// no Scope set or when utils.IsSubSet(scopes, <its scopes>) holds — presumably
// meaning every requested scope is covered by the restriction.
func (r Restrictions) WithScopes(scopes []string) (ret Restrictions) {
	log.WithField("scopes", scopes).WithField("len", len(scopes)).Trace("Filter restrictions for scopes")
	if len(scopes) == 0 {
		log.Trace("scopes empty, returning all restrictions")
		return r
	}
	for i := range r {
		candidate := r[i]
		unrestricted := len(candidate.Scope) == 0
		if unrestricted || utils.IsSubSet(scopes, utils.SplitIgnoreEmpty(candidate.Scope, " ")) {
			ret = append(ret, candidate)
		}
	}
	return ret
}
// WithAudiences filters the restrictions for the given audiences. With no
// audiences the receiver is returned unchanged. Otherwise a restriction is
// kept when it has no Audiences set or when
// utils.IsSubSet(audiences, <its audiences>) holds — presumably meaning every
// requested audience is covered by the restriction.
func (r Restrictions) WithAudiences(audiences []string) (ret Restrictions) {
	log.WithField("audiences", audiences).WithField("len", len(audiences)).Trace("Filter restrictions for audiences")
	if len(audiences) == 0 {
		log.Trace("audiences empty, returning all restrictions")
		return r
	}
	for i := range r {
		candidate := r[i]
		unrestricted := len(candidate.Audiences) == 0
		if unrestricted || utils.IsSubSet(audiences, candidate.Audiences) {
			ret = append(ret, candidate)
		}
	}
	return ret
}
// TokenUsages is a slice of TokenUsage.
type TokenUsages []TokenUsage
// TokenUsage holds usage counters for one token; the `db` tags map the fields
// to the columns ST_id, usages_other and usages_AT.
// NOTE(review): not referenced by the visible code — presumably a row of the
// TokenUsages table queried by the usage-count checks; confirm.
type TokenUsage struct {
	// STID is the value of the ST_id column.
	STID string `db:"ST_id"`
	// UsagesOtherUsed is the value of the usages_other column.
	UsagesOtherUsed uint `db:"usages_other"`
	// UsagesATUsed is the value of the usages_AT column.
	UsagesATUsed uint `db:"usages_AT"`
}
// Scan implements the sql.Scanner interface. It decodes a JSON array of
// restrictions from a database value.
//
// A nil src (SQL NULL) leaves the receiver untouched. Both []byte and string
// values are accepted, since drivers may deliver textual columns as either;
// any other type yields an error instead of panicking on a failed type
// assertion. JSON decoding errors are returned unchanged.
func (r *Restrictions) Scan(src interface{}) error {
	var raw []byte
	switch v := src.(type) {
	case nil:
		return nil
	case []byte:
		raw = v
	case string:
		raw = []byte(v)
	default:
		return errors.New("restrictions: unsupported source type for Scan")
	}
	// Decode straight into the receiver; passing &r would decode into a
	// pointer-to-pointer and hide nil-receiver mistakes.
	return json.Unmarshal(raw, r)
}
// Value implements the driver.Valuer interface. An empty or nil slice is
// stored as SQL NULL; anything else is stored as its JSON encoding.
func (r Restrictions) Value() (driver.Value, error) {
	if len(r) != 0 {
		encoded, err := json.Marshal(r)
		return encoded, err
	}
	return nil, nil
}
// GetExpires gets the maximum (latest) expiration time of all restrictions.
// It returns 0 for a nil or empty receiver, and also as soon as any
// restriction has no expiry (ExpiresAt == 0), because then the set as a whole
// never expires.
func (r *Restrictions) GetExpires() int64 {
	var latest int64
	if r == nil {
		return latest
	}
	list := *r
	for i := range list {
		switch e := list[i].ExpiresAt; {
		case e == 0: // if one entry has no expiry the max expiry is 0
			return 0
		case e > latest:
			latest = e
		}
	}
	return latest
}
// GetNotBefore gets the minimal (earliest) notbefore time of all restrictions.
// It returns 0 for a nil or empty receiver, and also as soon as any
// restriction has no notbefore (NotBefore == 0).
func (r *Restrictions) GetNotBefore() int64 {
	if r == nil {
		return 0
	}
	list := *r
	if len(list) == 0 {
		return 0
	}
	earliest := int64(math.MaxInt64)
	for i := range list {
		nb := list[i].NotBefore
		if nb == 0 { // if one entry has no notbefore the min notbefore is 0
			return 0
		}
		if nb < earliest {
			earliest = nb
		}
	}
	return earliest
}
// GetScopes returns the union of all scopes, i.e. all scopes that must be
// requested at the issuer. Each restriction's space-separated Scope string is
// split and the combined list is passed through utils.UniqueSlice.
//
// A nil receiver is treated like an empty set instead of panicking, matching
// GetExpires and GetNotBefore.
func (r *Restrictions) GetScopes() (scopes []string) {
	if r != nil {
		for _, rr := range *r {
			scopes = append(scopes, utils.SplitIgnoreEmpty(rr.Scope, " ")...)
		}
	}
	return utils.UniqueSlice(scopes)
}
// GetAudiences returns the union of all audiences, i.e. all audiences that
// must be requested at the issuer. The audiences of every restriction are
// concatenated and passed through utils.UniqueSlice.
//
// A nil receiver is treated like an empty set instead of panicking, matching
// GetExpires and GetNotBefore.
func (r *Restrictions) GetAudiences() (auds []string) {
	if r != nil {
		for _, rr := range *r {
			auds = append(auds, rr.Audiences...)
		}
	}
	return utils.UniqueSlice(auds)
}
// SetMaxScopes sets the maximum scopes, i.e. all scopes are stripped from the
// restrictions if not included in the passed argument. This is used to
// eliminate requested scopes that are dropped by the provider. Don't use it to
// eliminate scopes that are not enabled for the oidc client, because it also
// could be a custom scope.
//
// The elements are updated in place through their index: ranging by value
// would only modify a copy and leave the stored restrictions untouched.
// A nil receiver is a no-op.
//
// NOTE(review): a restriction whose scopes are all stripped ends up with an
// empty Scope, which WithScopes treats as "no scope limit" — confirm that this
// is the intended outcome.
func (r *Restrictions) SetMaxScopes(mScopes []string) {
	if r == nil {
		return
	}
	list := *r
	for i := range list {
		rScopes := utils.SplitIgnoreEmpty(list[i].Scope, " ")
		okScopes := utils.IntersectSlices(mScopes, rScopes)
		list[i].Scope = strings.Join(okScopes, " ")
	}
}
// SetMaxAudiences sets the maximum audiences, i.e. all audiences are stripped
// from the restrictions if not included in the passed argument. This is used
// to eliminate requested audiences that are dropped by the provider.
//
// The elements are updated in place through their index: ranging by value
// would only modify a copy and leave the stored restrictions untouched.
// A nil receiver is a no-op.
//
// NOTE(review): a restriction whose audiences are all stripped ends up with an
// empty Audiences list, which WithAudiences treats as "no audience limit" —
// confirm that this is the intended outcome.
func (r *Restrictions) SetMaxAudiences(mAud []string) {
	if r == nil {
		return
	}
	list := *r
	for i := range list {
		list[i].Audiences = utils.IntersectSlices(mAud, list[i].Audiences)
	}
}
// Tighten returns the subset of wanted that is acceptable given the existing
// restrictions old.
//
// If old is empty, wanted is returned as is. Otherwise each wanted restriction
// is kept only if it is tighter than some not-yet-consumed restriction of old;
// the matched old restriction is then consumed so it cannot justify a second
// wanted restriction. If no wanted restriction survives, old is returned.
//
// The matching works on a copy of old, so the caller's slice is not modified.
// If that copy cannot be made, the error is logged and old is returned.
func Tighten(old, wanted Restrictions) (res Restrictions) {
	if len(old) == 0 {
		return wanted
	}
	base := Restrictions{}
	if err := copier.Copy(&base, &old); err != nil {
		log.WithError(err).Error()
		return old
	}
	for _, a := range wanted {
		for j, o := range base {
			if a.IsTighterThan(o) {
				res = append(res, a)
				// Remove the matched element by its own index in base, not by
				// the index of a in wanted.
				base = append(base[:j], base[j+1:]...)
				break
			}
		}
	}
	if len(res) == 0 { // if all from wanted are dropped, because they are not tighter than old, then use old
		return old
	}
	return res
}
// RemoveIndex removes the element at index i, preserving the order of the
// remaining elements. The slice is modified in place and shortened by one.
//
// A nil receiver or an index outside [0, len) is a no-op rather than a panic.
// The vacated tail slot is reset to the zero value so the backing array does
// not keep the removed restriction's slices and pointers alive.
func (r *Restrictions) RemoveIndex(i int) {
	if r == nil || i < 0 || i >= len(*r) {
		return
	}
	list := *r
	last := len(list) - 1
	copy(list[i:], list[i+1:]) // Shift list[i+1:] left one index.
	list[last] = Restriction{} // Erase last element (write zero value).
	*r = list[:last]           // Truncate slice.
}
// IsTighterThan reports whether r is at least as restrictive as b in every
// dimension. Each check below rejects r as soon as it would allow something
// that b does not:
//   - r must not become valid earlier than b;
//   - if b expires, r must expire too, and not later;
//   - if b limits scopes, audiences, IPs or the geo white list, r must limit
//     them as well and pass the corresponding utils subset check against b;
//   - r's geo black list must contain every entry of b's;
//   - r's usage limits must not compare greater than b's according to
//     utils.CompareNullableIntsWithNilAsInfinity (presumably nil = unlimited).
func (r *Restriction) IsTighterThan(b Restriction) bool {
	// Validity window.
	if r.NotBefore < b.NotBefore {
		return false
	}
	if b.ExpiresAt != 0 && (r.ExpiresAt == 0 || r.ExpiresAt > b.ExpiresAt) {
		return false
	}

	// Scopes.
	rScopes := utils.SplitIgnoreEmpty(r.Scope, " ")
	if r.Scope == "" {
		rScopes = []string{}
	}
	bScopes := utils.SplitIgnoreEmpty(b.Scope, " ")
	if b.Scope == "" {
		bScopes = []string{}
	}
	if len(bScopes) != 0 && (len(rScopes) == 0 || !utils.IsSubSet(rScopes, bScopes)) {
		return false
	}

	// Audiences.
	if len(b.Audiences) != 0 && (len(r.Audiences) == 0 || !utils.IsSubSet(r.Audiences, b.Audiences)) {
		return false
	}

	// Explicit IPs.
	if len(b.IPs) != 0 && (len(r.IPs) == 0 || !utils.IPsAreSubSet(r.IPs, b.IPs)) {
		return false
	}

	// Geo-location lists.
	if len(b.GeoIPWhite) != 0 && (len(r.GeoIPWhite) == 0 || !utils.IsSubSet(r.GeoIPWhite, b.GeoIPWhite)) {
		return false
	}
	if !utils.IsSubSet(b.GeoIPBlack, r.GeoIPBlack) { // for Blacklist r must have all the values from b to be tighter
		return false
	}

	// Usage limits.
	if utils.CompareNullableIntsWithNilAsInfinity(r.UsagesAT, b.UsagesAT) > 0 {
		return false
	}
	return utils.CompareNullableIntsWithNilAsInfinity(r.UsagesOther, b.UsagesOther) <= 0
}