From c03235c70c122d41a5e8444617aba159c39fe9cb Mon Sep 17 00:00:00 2001
From: zachmann <gabriel.zachmann@kit.edu>
Date: Fri, 13 Nov 2020 16:48:46 +0100
Subject: [PATCH] add polling code; improve existing auth flow code

---
 internal/db/dbModels/accessToken.go           |  7 +-
 internal/db/dbModels/authcodeInfo.go          |  4 +-
 internal/db/dbModels/pollingCodeTmpST.go      | 37 ++++++++
 .../token/super/pkg/pollingCodeRequest.go     | 10 +++
 .../token/super/pkg/superTokenResponse.go     |  6 +-
 .../token/super/polling/pollingEndpoint.go    | 85 +++++++++++++++++++
 .../token/super/supertokenEndpoint.go         |  5 +-
 internal/model/apiError.go                    | 24 ++++--
 internal/oidc/authcode/authcode.go            | 44 ++++++++--
 internal/supertoken/pkg/supertoken.go         |  9 +-
 .../supertoken/restrictions/restriction.go    | 16 ++++
 internal/utils/utils.go                       |  8 ++
 12 files changed, 230 insertions(+), 25 deletions(-)
 create mode 100644 internal/db/dbModels/pollingCodeTmpST.go
 create mode 100644 internal/endpoints/token/super/pkg/pollingCodeRequest.go
 create mode 100644 internal/endpoints/token/super/polling/pollingEndpoint.go

diff --git a/internal/db/dbModels/accessToken.go b/internal/db/dbModels/accessToken.go
index 72ac056c..1e8c1fec 100644
--- a/internal/db/dbModels/accessToken.go
+++ b/internal/db/dbModels/accessToken.go
@@ -37,10 +37,10 @@ func (t *AccessToken) toDBObject() accessToken {
 func (t *AccessToken) getDBAttributes(atID uint64) (attrs []AccessTokenAttribute, err error) {
 	var scopeAttrID uint64
 	var audAttrID uint64
-	if err = db.DB().QueryRow(`SELECT id FROM Attributes WHERE attribute=?`, "scope").Scan(scopeAttrID); err != nil {
+	if err = db.DB().QueryRow(`SELECT id FROM Attributes WHERE attribute=?`, "scope").Scan(&scopeAttrID); err != nil {
 		return
 	}
-	if err = db.DB().QueryRow(`SELECT id FROM Attributes WHERE attribute=?`, "audience").Scan(audAttrID); err != nil {
+	if err = db.DB().QueryRow(`SELECT id FROM Attributes WHERE attribute=?`, "audience").Scan(&audAttrID); err != nil {
 		return
 	}
 	for _, s := range t.Scopes {
@@ -73,6 +73,9 @@ func (t *AccessToken) Store() error {
 				return err
 			}
 			attrs, err := t.getDBAttributes(uint64(atID))
+			if err != nil {
+				return err
+			}
 			if _, err := tx.NamedExec(`INSERT INTO AT_Attributes (AT_id, attribute_id, attribute) VALUES (:AT_id, :attribute_id, :attribute)`, attrs); err != nil {
 				return err
 			}
diff --git a/internal/db/dbModels/authcodeInfo.go b/internal/db/dbModels/authcodeInfo.go
index 42bcd611..a969544e 100644
--- a/internal/db/dbModels/authcodeInfo.go
+++ b/internal/db/dbModels/authcodeInfo.go
@@ -4,6 +4,8 @@ import (
 	"database/sql"
 	"log"
 
+	"github.com/zachmann/mytoken/internal/config"
+
 	"github.com/jmoiron/sqlx"
 
 	"github.com/zachmann/mytoken/internal/db"
@@ -56,7 +58,7 @@ func (i *AuthFlowInfo) Store() error {
 	store := i.toAuthFlowInfo()
 	return db.Transact(func(tx *sqlx.Tx) error {
 		if i.PollingCode != "" {
-			res, err := tx.Exec(`INSERT INTO PollingCodes (polling_code) VALUES(?)`, i.PollingCode)
+			res, err := tx.Exec(`INSERT INTO PollingCodes (polling_code, expires_in) VALUES(?, ?)`, i.PollingCode, config.Get().Polling.PollingCodeExpiresAfter)
 			if err != nil {
 				return err
 			}
diff --git a/internal/db/dbModels/pollingCodeTmpST.go b/internal/db/dbModels/pollingCodeTmpST.go
new file mode 100644
index 00000000..ea4d7dd3
--- /dev/null
+++ b/internal/db/dbModels/pollingCodeTmpST.go
@@ -0,0 +1,37 @@
+package dbModels
+
+import (
+	"database/sql"
+	"errors"
+
+	"github.com/zachmann/mytoken/internal/db"
+)
+
+type PollingCodeStatus struct {
+	Found   bool
+	Expired bool
+}
+
+func (p *PollingCodeStatus) Scan(src interface{}) error {
+	if src == nil {
+		p.Found = false
+		return nil
+	}
+	val := src.(int64)
+	p.Found = true
+	if val == 0 {
+		p.Expired = true
+	}
+	return nil
+}
+
+func CheckPollingCode(pollingCode string) (p PollingCodeStatus, err error) {
+	if err = db.DB().Get(&p, `SELECT CURRENT_TIMESTAMP() <= expires_at AS valid FROM PollingCodes WHERE polling_code=?`, pollingCode); err != nil {
+		if errors.Is(err, sql.ErrNoRows) {
+			err = nil // polling code was not found, but this is fine
+			return    // p.Found is false
+		}
+		return
+	}
+	return
+}
diff --git a/internal/endpoints/token/super/pkg/pollingCodeRequest.go b/internal/endpoints/token/super/pkg/pollingCodeRequest.go
new file mode 100644
index 00000000..4930ab14
--- /dev/null
+++ b/internal/endpoints/token/super/pkg/pollingCodeRequest.go
@@ -0,0 +1,10 @@
+package pkg
+
+import (
+	"github.com/zachmann/mytoken/internal/model"
+)
+
+type PollingCodeRequest struct {
+	GrantType   model.GrantType `json:"grant_type"`
+	PollingCode string          `json:"polling_code"`
+}
diff --git a/internal/endpoints/token/super/pkg/superTokenResponse.go b/internal/endpoints/token/super/pkg/superTokenResponse.go
index e81873a7..2ee7af6c 100644
--- a/internal/endpoints/token/super/pkg/superTokenResponse.go
+++ b/internal/endpoints/token/super/pkg/superTokenResponse.go
@@ -7,7 +7,7 @@ import (
 
 type SuperTokenResponse struct {
 	SuperToken   string                    `json:"super_token"`
-	ExpiresIn    uint64                    `json:"expires_in"`
-	Restrictions restrictions.Restrictions `json:"restrictions"`
-	Capabilities capabilities.Capabilities `json:"capabilities"`
+	ExpiresIn    uint64                    `json:"expires_in,omitempty"`
+	Restrictions restrictions.Restrictions `json:"restrictions,omitempty"`
+	Capabilities capabilities.Capabilities `json:"capabilities,omitempty"`
 }
diff --git a/internal/endpoints/token/super/polling/pollingEndpoint.go b/internal/endpoints/token/super/polling/pollingEndpoint.go
new file mode 100644
index 00000000..e953d7ee
--- /dev/null
+++ b/internal/endpoints/token/super/polling/pollingEndpoint.go
@@ -0,0 +1,85 @@
+package polling
+
+import (
+	"encoding/json"
+	"log"
+
+	supertoken "github.com/zachmann/mytoken/internal/supertoken/pkg"
+
+	"github.com/jmoiron/sqlx"
+
+	"github.com/zachmann/mytoken/internal/db"
+
+	"github.com/zachmann/mytoken/internal/db/dbModels"
+
+	"github.com/gofiber/fiber/v2"
+	response "github.com/zachmann/mytoken/internal/endpoints/token/super/pkg"
+	"github.com/zachmann/mytoken/internal/model"
+)
+
+func HandlePollingCode(ctx *fiber.Ctx) error {
+	req := response.PollingCodeRequest{}
+	if err := json.Unmarshal(ctx.Body(), &req); err != nil {
+		res := model.Response{
+			Status:   fiber.StatusBadRequest,
+			Response: model.BadRequestError(err.Error()),
+		}
+		return res.Send(ctx)
+	}
+	res := handlePollingCode(req)
+	return res.Send(ctx)
+}
+
+func handlePollingCode(req response.PollingCodeRequest) model.Response {
+	pollingCode := req.PollingCode
+	log.Printf("Handle polling code '%s'", pollingCode)
+	pollingCodeStatus, err := dbModels.CheckPollingCode(pollingCode)
+	if err != nil {
+		return model.ErrorToInternalServerErrorResponse(err)
+	}
+	if !pollingCodeStatus.Found {
+		log.Printf("Polling code '%s' not known", pollingCode)
+		return model.Response{
+			Status:   fiber.StatusUnauthorized,
+			Response: model.APIErrorBadPollingCode,
+		}
+	}
+	if pollingCodeStatus.Expired {
+		log.Printf("Polling code '%s' expired", pollingCode)
+		return model.Response{
+			Status:   fiber.StatusUnauthorized,
+			Response: model.APIErrorPollingCodeExpired,
+		}
+	}
+	var token string
+	if err := db.Transact(func(tx *sqlx.Tx) error {
+		if err := tx.Get(&token, `SELECT token FROM TmpST_by_polling_code WHERE polling_code=? AND CURRENT_TIMESTAMP() <= polling_code_expires_at`, pollingCode); err != nil {
+			return err
+		}
+		log.Printf("Retrieved token '%s' for polling code '%s' from db", token, pollingCode)
+		if token == "" {
+			return nil
+		}
+		if _, err := tx.Exec(`DELETE FROM PollingCodes WHERE polling_code=?`, pollingCode); err != nil {
+			return err
+		}
+		return nil
+	}); err != nil {
+		return model.ErrorToInternalServerErrorResponse(err)
+	}
+	if token == "" {
+		return model.Response{
+			Status:   fiber.StatusPreconditionRequired,
+			Response: model.APIErrorAuthorizationPending,
+		}
+	}
+	st, err := supertoken.ParseJWT(token)
+	if err != nil {
+		return model.ErrorToInternalServerErrorResponse(err)
+	}
+	log.Printf("The JWT was parsed as '%+v'", st)
+	return model.Response{
+		Status:   fiber.StatusOK,
+		Response: st.ToSuperTokenResponse(token),
+	}
+}
diff --git a/internal/endpoints/token/super/supertokenEndpoint.go b/internal/endpoints/token/super/supertokenEndpoint.go
index f25e16b5..62858423 100644
--- a/internal/endpoints/token/super/supertokenEndpoint.go
+++ b/internal/endpoints/token/super/supertokenEndpoint.go
@@ -1,6 +1,7 @@
 package super
 
 import (
+	"github.com/zachmann/mytoken/internal/endpoints/token/super/polling"
 	"github.com/zachmann/mytoken/internal/utils/ctxUtils"
 
 	"github.com/gofiber/fiber/v2"
@@ -15,9 +16,9 @@ func HandleSuperTokenEndpoint(ctx *fiber.Ctx) error {
 		return model.ResponseNYI.Send(ctx)
 	case model.GrantTypeOIDCFlow:
 		return handleOIDCFlow(ctx)
-	case model.GrantTypeAccessToken:
-		return model.ResponseNYI.Send(ctx)
 	case model.GrantTypePollingCode:
+		return polling.HandlePollingCode(ctx)
+	case model.GrantTypeAccessToken:
 		return model.ResponseNYI.Send(ctx)
 	case model.GrantTypePrivateKeyJWT:
 		return model.ResponseNYI.Send(ctx)
diff --git a/internal/model/apiError.go b/internal/model/apiError.go
index be19ca76..1d115b9b 100644
--- a/internal/model/apiError.go
+++ b/internal/model/apiError.go
@@ -16,12 +16,15 @@ var ResponseNYI = Response{fiber.StatusNotImplemented, APIErrorNYI}
 
 // Predefined errors
 var (
-	APIErrorUnknownIssuer    = APIError{ErrorInvalidRequest, "The provided issuer is not supported"}
-	APIErrorStateMismatch    = APIError{ErrorInvalidRequest, "State mismatched"}
-	APIErrorUnknownOIDCFlow  = APIError{ErrorInvalidGrant, "Unknown oidc_flow"}
-	APIErrorUnknownGrantType = APIError{ErrorInvalidGrant, "Unknown grant_type"}
-	APIErrorNoRefreshToken   = APIError{ErrorOIDC, "Did not receive a refresh token"}
-	APIErrorNYI              = APIError{ErrorNYI, ""}
+	APIErrorUnknownIssuer        = APIError{ErrorInvalidRequest, "The provided issuer is not supported"}
+	APIErrorStateMismatch        = APIError{ErrorInvalidRequest, "State mismatched"}
+	APIErrorUnknownOIDCFlow      = APIError{ErrorInvalidGrant, "Unknown oidc_flow"}
+	APIErrorUnknownGrantType     = APIError{ErrorInvalidGrant, "Unknown grant_type"}
+	APIErrorBadPollingCode       = APIError{ErrorAccessDenied, "Bad polling_code"}
+	APIErrorPollingCodeExpired   = APIError{ErrorExpiredToken, "polling_code is expired"}
+	APIErrorAuthorizationPending = ErrorWithoutDescription(ErrorAuthorizationPending)
+	APIErrorNoRefreshToken       = APIError{ErrorOIDC, "Did not receive a refresh token"}
+	APIErrorNYI                  = ErrorWithoutDescription(ErrorNYI)
 )
 
 // Predefined OAuth2/OIDC errors
@@ -34,6 +37,9 @@ const (
 	ErrorInvalidScope         = "invalid_scope"
 	ErrorInvalidToken         = "invalid_token"
 	ErrorInsufficientScope    = "insufficient_scope"
+	ErrorExpiredToken         = "expired_token"
+	ErrorAccessDenied         = "access_denied"
+	ErrorAuthorizationPending = "authorization_pending"
 )
 
 // Additional Mytoken errors
@@ -77,3 +83,9 @@ func BadRequestError(errorDescription string) APIError {
 		ErrorDescription: errorDescription,
 	}
 }
+
+func ErrorWithoutDescription(error string) APIError {
+	return APIError{
+		Error: error,
+	}
+}
diff --git a/internal/oidc/authcode/authcode.go b/internal/oidc/authcode/authcode.go
index f1f8e8d1..1c158082 100644
--- a/internal/oidc/authcode/authcode.go
+++ b/internal/oidc/authcode/authcode.go
@@ -7,6 +7,11 @@ import (
 	"errors"
 	"fmt"
 	"log"
+	"strings"
+
+	"github.com/zachmann/mytoken/internal/supertoken/restrictions"
+
+	"github.com/dgrijalva/jwt-go"
 
 	"github.com/jmoiron/sqlx"
 	"github.com/zachmann/mytoken/internal/db"
@@ -34,7 +39,7 @@ func Init() {
 }
 
 const stateLen = 16
-const pollingCodeLen = 16
+const pollingCodeLen = 32
 
 type stateInfo struct {
 	Native bool
@@ -62,14 +67,18 @@ func parseState(state string) stateInfo {
 	return info
 }
 
-func authorizationURL(provider *config.ProviderConf, native bool) (string, string) {
+func authorizationURL(provider *config.ProviderConf, restrictions restrictions.Restrictions, native bool) (string, string) {
 	log.Printf("Generating authorization url")
+	scopes := restrictions.GetScopes()
+	if len(scopes) <= 0 {
+		scopes = provider.Scopes
+	}
 	oauth2Config := oauth2.Config{
 		ClientID:     provider.ClientID,
 		ClientSecret: provider.ClientSecret,
 		Endpoint:     provider.Provider.Endpoint(),
 		RedirectURL:  redirectURL,
-		Scopes:       provider.Scopes, //TODO use restrictions
+		Scopes:       scopes,
 	}
 	state := createState(stateInfo{Native: native})
 	additionalParams := []oauth2.AuthCodeOption{oauth2.ApprovalForce}
@@ -78,7 +87,10 @@ func authorizationURL(provider *config.ProviderConf, native bool) (string, strin
 	} else if !utils.StringInSlice(oidc.ScopeOfflineAccess, oauth2Config.Scopes) {
 		oauth2Config.Scopes = append(oauth2Config.Scopes, oidc.ScopeOfflineAccess)
 	}
-	//TODO add audience from restriction
+	auds := restrictions.GetAudiences()
+	if len(auds) > 0 {
+		additionalParams = append(additionalParams, oauth2.SetAuthURLParam("audience", strings.Join(auds, " ")))
+	}
 
 	return oauth2Config.AuthCodeURL(state, additionalParams...), state
 }
@@ -100,7 +112,7 @@ func InitAuthCodeFlow(body []byte) model.Response {
 		}
 	}
 
-	authURL, state := authorizationURL(provider, req.Native())
+	authURL, state := authorizationURL(provider, req.Restrictions, req.Native())
 	authFlowInfo := dbModels.AuthFlowInfo{
 		State:        state,
 		Issuer:       provider.Issuer,
@@ -173,6 +185,22 @@ func CodeExchange(state, code string, networkData model.NetworkData) model.Respo
 			Response: model.APIErrorNoRefreshToken,
 		}
 	}
+	scopes := authInfo.Restrictions.GetScopes()
+	scopesStr, ok := token.Extra("scope").(string)
+	if ok && scopesStr != "" {
+		scopes = strings.Split(scopesStr, " ")
+		authInfo.Restrictions.SetMaxScopes(scopes) // Update restrictions with correct scopes
+	}
+	audiences := authInfo.Restrictions.GetAudiences()
+	if atJWT, _ := jwt.Parse(token.AccessToken, nil); atJWT != nil {
+		if claims, ok := atJWT.Claims.(jwt.MapClaims); ok {
+			if tmp, ok := claims["aud"].([]string); ok {
+				audiences = tmp
+			}
+		}
+	}
+	authInfo.Restrictions.SetMaxAudiences(audiences) // Update restrictions with correct audiences
+
 	oidcSub, err := getSubjectFromUserinfo(provider.Provider, token)
 	if err != nil {
 		return model.ErrorToInternalServerErrorResponse(err)
@@ -186,8 +214,8 @@ func CodeExchange(state, code string, networkData model.NetworkData) model.Respo
 		IP:        networkData.IP,
 		Comment:   "Initial Access Token from authorization code flow",
 		STID:      ste.ID,
-		Scopes:    nil, //TODO
-		Audiences: nil, //TODO
+		Scopes:    scopes,
+		Audiences: audiences,
 	}
 	if err := at.Store(); err != nil {
 		return model.ErrorToInternalServerErrorResponse(err)
@@ -214,7 +242,7 @@ func CodeExchange(state, code string, networkData model.NetworkData) model.Respo
 	}
 	return model.Response{
 		Status:   fiber.StatusOK,
-		Response: ste.Token.ToSuperTokenResponse(), //TODO
+		Response: ste.Token.ToSuperTokenResponse(""), //TODO redirect
 	}
 }
 
diff --git a/internal/supertoken/pkg/supertoken.go b/internal/supertoken/pkg/supertoken.go
index 74d25d76..728923df 100644
--- a/internal/supertoken/pkg/supertoken.go
+++ b/internal/supertoken/pkg/supertoken.go
@@ -92,10 +92,13 @@ func (st *SuperToken) Valid() error {
 	return nil
 }
 
-// ToSuperTokenResponse returns a SuperTokenResponse for this token. It requires that jwt is set, i.e. ToJWT must have been called earlier on this token. This is always the case, if the token has been stored.
-func (st *SuperToken) ToSuperTokenResponse() response.SuperTokenResponse {
+// ToSuperTokenResponse returns a SuperTokenResponse for this token. It requires that jwt is set or that the jwt is passed as argument; if not passed as argument ToJWT must have been called earlier on this token to set jwt. This is always the case, if the token has been stored.
+func (st *SuperToken) ToSuperTokenResponse(jwt string) response.SuperTokenResponse {
+	if jwt == "" {
+		jwt = st.jwt
+	}
 	return response.SuperTokenResponse{
-		SuperToken:   st.jwt,
+		SuperToken:   jwt,
 		ExpiresIn:    st.ExpiresIn(),
 		Restrictions: st.Restrictions,
 		Capabilities: st.Capabilities,
diff --git a/internal/supertoken/restrictions/restriction.go b/internal/supertoken/restrictions/restriction.go
index 10f814f3..9ba338e3 100644
--- a/internal/supertoken/restrictions/restriction.go
+++ b/internal/supertoken/restrictions/restriction.go
@@ -106,6 +106,22 @@ func (r *Restrictions) GetAudiences() (auds []string) {
 	return
 }
 
+// SetMaxScopes sets the maximum scopes, i.e. all scopes are stripped from the restrictions if not included in the passed argument. This is used to eliminate requested scopes that are dropped by the provider. Don't use it to eliminate scopes that are not enabled for the oidc client, because it also could be a custom scope.
+func (r *Restrictions) SetMaxScopes(mScopes []string) {
+	for _, rr := range *r {
+		rScopes := strings.Split(rr.Scope, " ")
+		okScopes := utils.IntersectSlices(mScopes, rScopes)
+		rr.Scope = strings.Join(okScopes, " ")
+	}
+}
+
+// SetMaxAudiences sets the maximum audiences, i.e. all audiences are stripped from the restrictions if not included in the passed argument. This is used to eliminate requested audiences that are dropped by the provider.
+func (r *Restrictions) SetMaxAudiences(mAud []string) {
+	for _, rr := range *r {
+		rr.Audiences = utils.IntersectSlices(mAud, rr.Audiences)
+	}
+}
+
 func Tighten(old, wanted Restrictions) (res Restrictions) {
 	if len(old) == 0 {
 		return wanted
diff --git a/internal/utils/utils.go b/internal/utils/utils.go
index 9c27eaf6..9ebbfb8a 100644
--- a/internal/utils/utils.go
+++ b/internal/utils/utils.go
@@ -131,3 +131,11 @@ func UniqueSlice(a []string) (unique []string) {
 func SliceUnion(a, b []string) []string {
 	return UniqueSlice(append(a, b...))
 }
+
+func GetTimeIn(seconds int64) time.Time {
+	return time.Now().Add(time.Duration(seconds) * time.Second)
+}
+
+func GetUnixTimeIn(seconds int64) int64 {
+	return GetTimeIn(seconds).Unix()
+}
-- 
GitLab