diff --git a/internal/db/dbmigrate/v0.3.0.go b/internal/db/dbmigrate/v0.3.0.go index d9842c576704aafe21e503dfb2d7e36d46b8758b..e5a6e956e50087f025114ab7d045acf317da6666 100644 --- a/internal/db/dbmigrate/v0.3.0.go +++ b/internal/db/dbmigrate/v0.3.0.go @@ -2,12 +2,13 @@ package dbmigrate var v0_3_0_Before = []string{ // Tables - "ALTER TABLE AuthInfo ADD rotation json NULL", "DROP TRIGGER updtrigger", "ALTER TABLE RT_EncryptionKeys ADD CONSTRAINT RT_EncryptionKeys_FK_2 FOREIGN KEY (MT_id) REFERENCES MTokens(id) ON DELETE CASCADE ON UPDATE CASCADE", "TRUNCATE TABLE AuthInfo", + "ALTER TABLE AuthInfo ADD rotation json NULL", "ALTER TABLE AuthInfo ADD response_type varchar(128) NOT NULL", "ALTER TABLE AuthInfo ADD max_token_len INT DEFAULT NULL NULL", + "ALTER TABLE AuthInfo ADD code_verifier varchar(128) NULL", "ALTER TABLE TransferCodesAttributes ADD max_token_len INT NULL", "CREATE OR REPLACE" + "ALGORITHM = UNDEFINED VIEW `mytoken_test`.`TransferCodes` AS" + diff --git a/internal/db/dbrepo/authcodeinforepo/authcodeInfo.go b/internal/db/dbrepo/authcodeinforepo/authcodeInfo.go index 29e5c11e98d98c932029f923a8c0d09f2dcebf2f..1fa1e2b8952c0b5a9043f3be65dc5bba0c3b2743 100644 --- a/internal/db/dbrepo/authcodeinforepo/authcodeInfo.go +++ b/internal/db/dbrepo/authcodeinforepo/authcodeInfo.go @@ -35,6 +35,7 @@ type AuthFlowInfoOut struct { Rotation *api.Rotation ResponseType model.ResponseType MaxTokenLen int + CodeVerifier string } type authFlowInfo struct { @@ -49,6 +50,7 @@ type authFlowInfo struct { Rotation *api.Rotation ResponseType model.ResponseType `db:"response_type"` MaxTokenLen *int `db:"max_token_len"` + CodeVerifier db.NullString `db:"code_verifier"` } func (i *AuthFlowInfo) toAuthFlowInfo() *authFlowInfo { @@ -78,6 +80,7 @@ func (i *authFlowInfo) toAuthFlowInfo() *AuthFlowInfoOut { PollingCode: bool(i.PollingCode), Rotation: i.Rotation, ResponseType: i.ResponseType, + CodeVerifier: i.CodeVerifier.String, } if i.MaxTokenLen != nil { o.MaxTokenLen = *i.MaxTokenLen @@ -111,7 +114,7 @@ func GetAuthFlowInfoByState(state *state.State) (*AuthFlowInfoOut, error) { if err := db.Transact(func(tx *sqlx.Tx) error { return errors.WithStack(tx.Get(&info, `SELECT state_h, iss, restrictions, capabilities, subtoken_capabilities, - name, polling_code, rotation, response_type, max_token_len FROM AuthInfo + name, polling_code, rotation, response_type, max_token_len, code_verifier FROM AuthInfo WHERE state_h=? AND expires_at >= CURRENT_TIMESTAMP()`, state)) }); err != nil { @@ -138,3 +141,11 @@ func UpdateTokenInfoByState(tx *sqlx.Tx, state *state.State, r restrictions.Rest return errors.WithStack(err) }) } + +// SetCodeVerifier stores the passed PKCE code verifier +func SetCodeVerifier(tx *sqlx.Tx, state *state.State, verifier string) error { + return db.RunWithinTransaction(tx, func(tx *sqlx.Tx) error { + _, err := tx.Exec(`UPDATE AuthInfo SET code_verifier=? WHERE state_h=?`, verifier, state) + return errors.WithStack(err) + }) +} diff --git a/internal/endpoints/consent/consent.go b/internal/endpoints/consent/consent.go index 6803a319c3404c4f3eaeba175d4d248843107361..b347f2859469edbac07cea5024e0527e099483e4 100644 --- a/internal/endpoints/consent/consent.go +++ b/internal/endpoints/consent/consent.go @@ -7,9 +7,12 @@ import ( "strings" "github.com/gofiber/fiber/v2" + "github.com/jmoiron/sqlx" "github.com/pkg/errors" log "github.com/sirupsen/logrus" + "github.com/oidc-mytoken/server/internal/db" + "github.com/oidc-mytoken/server/internal/oidc/pkce" "github.com/oidc-mytoken/server/internal/server/httpStatus" "github.com/oidc-mytoken/server/internal/utils/errorfmt" @@ -130,9 +133,15 @@ func HandleConsentPost(ctx *fiber.Ctx) error { }.Send(ctx) } } - if err = authcodeinforepo.UpdateTokenInfoByState( - nil, oState, req.Restrictions, req.Capabilities, req.SubtokenCapabilities, req.Rotation, req.TokenName, - ); err != nil { + pkceCode := pkce.NewS256PKCE(utils.RandASCIIString(32)) + if err = db.Transact(func(tx *sqlx.Tx) error { + if err = authcodeinforepo.UpdateTokenInfoByState( + tx, oState, req.Restrictions, req.Capabilities, req.SubtokenCapabilities, req.Rotation, req.TokenName, + ); err != nil { + return err + } + return authcodeinforepo.SetCodeVerifier(tx, oState, pkceCode.Verifier()) + }); err != nil { log.Errorf("%s", errorfmt.Full(err)) return model.ErrorToInternalServerErrorResponse(err).Send(ctx) } @@ -143,7 +152,8 @@ func HandleConsentPost(ctx *fiber.Ctx) error { Response: api.ErrorUnknownIssuer, }.Send(ctx) } - authURL := authcode.GetAuthorizationURL(provider, oState.State(), req.Restrictions) + pkceChallenge, _ := pkceCode.Challenge() + authURL := authcode.GetAuthorizationURL(provider, oState.State(), pkceChallenge, req.Restrictions) return model.Response{ Status: httpStatus.StatusOKForward, Response: map[string]string{ diff --git a/internal/oidc/authcode/authcode.go b/internal/oidc/authcode/authcode.go index 3df3b0d481b1a64617ca8ee93d36d2bf7d29af0e..9713b2cce7929c9535e836b0ddfb9fa07eb97f3b 100644 --- a/internal/oidc/authcode/authcode.go +++ b/internal/oidc/authcode/authcode.go @@ -7,13 +7,16 @@ import ( "github.com/coreos/go-oidc/v3/oidc" "github.com/gofiber/fiber/v2" "github.com/jmoiron/sqlx" - "github.com/oidc-mytoken/server/internal/utils/cookies" - "github.com/oidc-mytoken/server/internal/utils/errorfmt" "github.com/pkg/errors" log "github.com/sirupsen/logrus" "golang.org/x/oauth2" + "github.com/oidc-mytoken/server/internal/oidc/pkce" + "github.com/oidc-mytoken/server/internal/utils/cookies" + "github.com/oidc-mytoken/server/internal/utils/errorfmt" + "github.com/oidc-mytoken/api/v0" + "github.com/oidc-mytoken/server/internal/config" "github.com/oidc-mytoken/server/internal/db" "github.com/oidc-mytoken/server/internal/db/dbrepo/accesstokenrepo" @@ -46,7 +49,7 @@ func Init() { } // GetAuthorizationURL creates a authorization url -func GetAuthorizationURL(provider *config.ProviderConf, oState string, restrictions restrictions.Restrictions) string { +func GetAuthorizationURL(provider *config.ProviderConf, oState, pkceChallenge string, restrictions restrictions.Restrictions) string { log.Debug("Generating authorization url") scopes := restrictions.GetScopes() if len(scopes) <= 0 { @@ -59,7 +62,11 @@ func GetAuthorizationURL(provider *config.ProviderConf, oState string, restricti RedirectURL: redirectURL, Scopes: scopes, } - additionalParams := []oauth2.AuthCodeOption{oauth2.ApprovalForce} + additionalParams := []oauth2.AuthCodeOption{ + oauth2.ApprovalForce, + oauth2.SetAuthURLParam("code_challenge", pkceChallenge), + oauth2.SetAuthURLParam("code_challenge_method", pkce.TransformationS256.String()), + } if issuerUtils.CompareIssuerURLs(provider.Issuer, issuer.GOOGLE) { additionalParams = append(additionalParams, oauth2.AccessTypeOffline) } else if !utils.StringInSlice(oidc.ScopeOfflineAccess, oauth2Config.Scopes) { @@ -132,6 +139,8 @@ func StartAuthCodeFlow(ctx *fiber.Ctx, oidcReq response.OIDCFlowRequest) *model. } } +//TODO don't return json + // CodeExchange performs an oidc code exchange it creates the mytoken and stores it in the database func CodeExchange(oState *state.State, code string, networkData api.ClientMetaData) *model.Response { log.Debug("Handle code exchange") @@ -159,7 +168,7 @@ func CodeExchange(oState *state.State, code string, networkData api.ClientMetaDa Endpoint: provider.Endpoints.OAuth2(), RedirectURL: redirectURL, } - token, err := oauth2Config.Exchange(context.Get(), code) + token, err := oauth2Config.Exchange(context.Get(), code, oauth2.SetAuthURLParam("code_verifier", authInfo.CodeVerifier)) if err != nil { var e *oauth2.RetrieveError if errors.As(err, &e) { diff --git a/internal/oidc/pkce/pkce.go b/internal/oidc/pkce/pkce.go new file mode 100644 index 0000000000000000000000000000000000000000..d2363eb57c6b743ed6a077b0426fb123867bf82b --- /dev/null +++ b/internal/oidc/pkce/pkce.go @@ -0,0 +1,75 @@ +package pkce + +import ( + "crypto/sha256" + "encoding/base64" + + "github.com/pkg/errors" +) + +// PKCE is a type holding the information for a PKCE flow +type PKCE struct { + verifier string + challenge string + method PKCEMethod +} + +// PKCEMethod is a type for the code challenge methods +type PKCEMethod string + +// Defines for the possible PKCEMethod +const ( + TransformationPlain = PKCEMethod("plain") + TransformationS256 = PKCEMethod("S256") +) + +func (m PKCEMethod) String() string { + return string(m) +} + +// NewPKCE creates a new PKCE for the passed verifier and PKCEMethod +func NewPKCE(verifier string, method PKCEMethod) *PKCE { + return &PKCE{ + verifier: verifier, + method: method, + } +} + +// NewS256PKCE creates a new PKCE for the passed verifier and the PKCEMethod TransformationS256 +func NewS256PKCE(verifier string) *PKCE { + return NewPKCE(verifier, TransformationS256) +} + +// Verifier returns the code_verifier +func (pkce PKCE) Verifier() string { + return pkce.verifier +} + +// Challenge returns the code_challenge according to the defined PKCEMethod +func (pkce *PKCE) Challenge() (string, error) { + var err error + if pkce.challenge == "" { + pkce.challenge, err = pkce.transform() + } + return pkce.challenge, err +} + +func (pkce PKCE) transform() (string, error) { + switch pkce.method { + case TransformationPlain: + return pkce.plain(), nil + case TransformationS256: + return pkce.s256(), nil + default: + return "", errors.New("unknown code_challenge_method") + } +} + +func (pkce PKCE) plain() string { + return pkce.verifier +} + +func (pkce PKCE) s256() string { + hash := sha256.Sum256([]byte(pkce.verifier)) + return base64.URLEncoding.WithPadding(base64.NoPadding).EncodeToString(hash[:]) +}