Skip to content
Snippets Groups Projects
Commit d9587630 authored by Gabriel Zachmann's avatar Gabriel Zachmann
Browse files

add PKCE

parent 61331923
No related branches found
No related tags found
No related merge requests found
......@@ -2,12 +2,13 @@ package dbmigrate
var v0_3_0_Before = []string{
// Tables
"ALTER TABLE AuthInfo ADD rotation json NULL",
"DROP TRIGGER updtrigger",
"ALTER TABLE RT_EncryptionKeys ADD CONSTRAINT RT_EncryptionKeys_FK_2 FOREIGN KEY (MT_id) REFERENCES MTokens(id) ON DELETE CASCADE ON UPDATE CASCADE",
"TRUNCATE TABLE AuthInfo",
"ALTER TABLE AuthInfo ADD rotation json NULL",
"ALTER TABLE AuthInfo ADD response_type varchar(128) NOT NULL",
"ALTER TABLE AuthInfo ADD max_token_len INT DEFAULT NULL NULL",
"ALTER TABLE AuthInfo ADD code_verifier varchar(128) NULL",
"ALTER TABLE TransferCodesAttributes ADD max_token_len INT NULL",
"CREATE OR REPLACE" +
"ALGORITHM = UNDEFINED VIEW `mytoken_test`.`TransferCodes` AS" +
......
......@@ -35,6 +35,7 @@ type AuthFlowInfoOut struct {
Rotation *api.Rotation
ResponseType model.ResponseType
MaxTokenLen int
CodeVerifier string
}
type authFlowInfo struct {
......@@ -49,6 +50,7 @@ type authFlowInfo struct {
Rotation *api.Rotation
ResponseType model.ResponseType `db:"response_type"`
MaxTokenLen *int `db:"max_token_len"`
CodeVerifier db.NullString `db:"code_verifier"`
}
func (i *AuthFlowInfo) toAuthFlowInfo() *authFlowInfo {
......@@ -78,6 +80,7 @@ func (i *authFlowInfo) toAuthFlowInfo() *AuthFlowInfoOut {
PollingCode: bool(i.PollingCode),
Rotation: i.Rotation,
ResponseType: i.ResponseType,
CodeVerifier: i.CodeVerifier.String,
}
if i.MaxTokenLen != nil {
o.MaxTokenLen = *i.MaxTokenLen
......@@ -111,7 +114,7 @@ func GetAuthFlowInfoByState(state *state.State) (*AuthFlowInfoOut, error) {
if err := db.Transact(func(tx *sqlx.Tx) error {
return errors.WithStack(tx.Get(&info,
`SELECT state_h, iss, restrictions, capabilities, subtoken_capabilities,
name, polling_code, rotation, response_type, max_token_len FROM AuthInfo
name, polling_code, rotation, response_type, max_token_len, code_verifier FROM AuthInfo
WHERE state_h=? AND expires_at >= CURRENT_TIMESTAMP()`,
state))
}); err != nil {
......@@ -138,3 +141,11 @@ func UpdateTokenInfoByState(tx *sqlx.Tx, state *state.State, r restrictions.Rest
return errors.WithStack(err)
})
}
// SetCodeVerifier stores the passed PKCE code verifier
func SetCodeVerifier(tx *sqlx.Tx, state *state.State, verifier string) error {
return db.RunWithinTransaction(tx, func(tx *sqlx.Tx) error {
_, err := tx.Exec(`UPDATE AuthInfo SET code_verifier=? WHERE state_h=?`, verifier, state)
return errors.WithStack(err)
})
}
......@@ -7,9 +7,12 @@ import (
"strings"
"github.com/gofiber/fiber/v2"
"github.com/jmoiron/sqlx"
"github.com/pkg/errors"
log "github.com/sirupsen/logrus"
"github.com/oidc-mytoken/server/internal/db"
"github.com/oidc-mytoken/server/internal/oidc/pkce"
"github.com/oidc-mytoken/server/internal/server/httpStatus"
"github.com/oidc-mytoken/server/internal/utils/errorfmt"
......@@ -130,9 +133,15 @@ func HandleConsentPost(ctx *fiber.Ctx) error {
}.Send(ctx)
}
}
if err = authcodeinforepo.UpdateTokenInfoByState(
nil, oState, req.Restrictions, req.Capabilities, req.SubtokenCapabilities, req.Rotation, req.TokenName,
); err != nil {
pkceCode := pkce.NewS256PKCE(utils.RandASCIIString(32))
if err = db.Transact(func(tx *sqlx.Tx) error {
if err = authcodeinforepo.UpdateTokenInfoByState(
tx, oState, req.Restrictions, req.Capabilities, req.SubtokenCapabilities, req.Rotation, req.TokenName,
); err != nil {
return err
}
return authcodeinforepo.SetCodeVerifier(tx, oState, pkceCode.Verifier())
}); err != nil {
log.Errorf("%s", errorfmt.Full(err))
return model.ErrorToInternalServerErrorResponse(err).Send(ctx)
}
......@@ -143,7 +152,8 @@ func HandleConsentPost(ctx *fiber.Ctx) error {
Response: api.ErrorUnknownIssuer,
}.Send(ctx)
}
authURL := authcode.GetAuthorizationURL(provider, oState.State(), req.Restrictions)
pkceChallenge, _ := pkceCode.Challenge()
authURL := authcode.GetAuthorizationURL(provider, oState.State(), pkceChallenge, req.Restrictions)
return model.Response{
Status: httpStatus.StatusOKForward,
Response: map[string]string{
......
......@@ -7,13 +7,16 @@ import (
"github.com/coreos/go-oidc/v3/oidc"
"github.com/gofiber/fiber/v2"
"github.com/jmoiron/sqlx"
"github.com/oidc-mytoken/server/internal/utils/cookies"
"github.com/oidc-mytoken/server/internal/utils/errorfmt"
"github.com/pkg/errors"
log "github.com/sirupsen/logrus"
"golang.org/x/oauth2"
"github.com/oidc-mytoken/server/internal/oidc/pkce"
"github.com/oidc-mytoken/server/internal/utils/cookies"
"github.com/oidc-mytoken/server/internal/utils/errorfmt"
"github.com/oidc-mytoken/api/v0"
"github.com/oidc-mytoken/server/internal/config"
"github.com/oidc-mytoken/server/internal/db"
"github.com/oidc-mytoken/server/internal/db/dbrepo/accesstokenrepo"
......@@ -46,7 +49,7 @@ func Init() {
}
// GetAuthorizationURL creates a authorization url
func GetAuthorizationURL(provider *config.ProviderConf, oState string, restrictions restrictions.Restrictions) string {
func GetAuthorizationURL(provider *config.ProviderConf, oState, pkceChallenge string, restrictions restrictions.Restrictions) string {
log.Debug("Generating authorization url")
scopes := restrictions.GetScopes()
if len(scopes) <= 0 {
......@@ -59,7 +62,11 @@ func GetAuthorizationURL(provider *config.ProviderConf, oState string, restricti
RedirectURL: redirectURL,
Scopes: scopes,
}
additionalParams := []oauth2.AuthCodeOption{oauth2.ApprovalForce}
additionalParams := []oauth2.AuthCodeOption{
oauth2.ApprovalForce,
oauth2.SetAuthURLParam("code_challenge", pkceChallenge),
oauth2.SetAuthURLParam("code_challenge_method", pkce.TransformationS256.String()),
}
if issuerUtils.CompareIssuerURLs(provider.Issuer, issuer.GOOGLE) {
additionalParams = append(additionalParams, oauth2.AccessTypeOffline)
} else if !utils.StringInSlice(oidc.ScopeOfflineAccess, oauth2Config.Scopes) {
......@@ -132,6 +139,8 @@ func StartAuthCodeFlow(ctx *fiber.Ctx, oidcReq response.OIDCFlowRequest) *model.
}
}
//TODO don't return json
// CodeExchange performs an oidc code exchange it creates the mytoken and stores it in the database
func CodeExchange(oState *state.State, code string, networkData api.ClientMetaData) *model.Response {
log.Debug("Handle code exchange")
......@@ -159,7 +168,7 @@ func CodeExchange(oState *state.State, code string, networkData api.ClientMetaDa
Endpoint: provider.Endpoints.OAuth2(),
RedirectURL: redirectURL,
}
token, err := oauth2Config.Exchange(context.Get(), code)
token, err := oauth2Config.Exchange(context.Get(), code, oauth2.SetAuthURLParam("code_verifier", authInfo.CodeVerifier))
if err != nil {
var e *oauth2.RetrieveError
if errors.As(err, &e) {
......
package pkce
import (
"crypto/sha256"
"encoding/base64"
"github.com/pkg/errors"
)
// PKCE is a type holding the information for a PKCE flow
type PKCE struct {
verifier string
challenge string
method PKCEMethod
}
// PKCEMethod is a type for the code challenge methods
type PKCEMethod string
// Defines for the possible PKCEMethod
const (
TransformationPlain = PKCEMethod("plain")
TransformationS256 = PKCEMethod("S256")
)
func (m PKCEMethod) String() string {
return string(m)
}
// NewPKCE creates a new PKCE for the passed verifier and PKCEMethod
func NewPKCE(verifier string, method PKCEMethod) *PKCE {
return &PKCE{
verifier: verifier,
method: method,
}
}
// NewS256PKCE creates a new PKCE for the passed verifier and the PKCEMethod TransformationS256
func NewS256PKCE(verifier string) *PKCE {
return NewPKCE(verifier, TransformationS256)
}
// Verifier returns the code_verifier
func (pkce PKCE) Verifier() string {
return pkce.verifier
}
// Challenge returns the code_challenge according to the defined PKCEMethod
func (pkce *PKCE) Challenge() (string, error) {
var err error
if pkce.challenge == "" {
pkce.challenge, err = pkce.transform()
}
return pkce.challenge, err
}
func (pkce PKCE) transform() (string, error) {
switch pkce.method {
case TransformationPlain:
return pkce.plain(), nil
case TransformationS256:
return pkce.s256(), nil
default:
return "", errors.New("unknown code_challenge_method")
}
}
func (pkce PKCE) plain() string {
return pkce.verifier
}
func (pkce PKCE) s256() string {
hash := sha256.Sum256([]byte(pkce.verifier))
return base64.URLEncoding.WithPadding(base64.NoPadding).EncodeToString(hash[:])
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment