Skip to content
Snippets Groups Projects
Commit 924acb21 authored by Gabriel Zachmann's avatar Gabriel Zachmann
Browse files

implement transfer code exchange

parent 3f578e48
No related branches found
No related tags found
No related merge requests found
...@@ -19,8 +19,8 @@ type TransferCodeStatus struct { ...@@ -19,8 +19,8 @@ type TransferCodeStatus struct {
ResponseType model.ResponseType `db:"response_type"` ResponseType model.ResponseType `db:"response_type"`
} }
// CheckPollingCode checks the passed polling code in the database // CheckTransferCode checks the passed polling code in the database
func CheckPollingCode(tx *sqlx.Tx, pollingCode string) (TransferCodeStatus, error) { func CheckTransferCode(tx *sqlx.Tx, pollingCode string) (TransferCodeStatus, error) {
pt := createProxyToken(pollingCode) pt := createProxyToken(pollingCode)
var p TransferCodeStatus var p TransferCodeStatus
err := db.RunWithinTransaction(tx, func(tx *sqlx.Tx) error { err := db.RunWithinTransaction(tx, func(tx *sqlx.Tx) error {
...@@ -36,8 +36,8 @@ func CheckPollingCode(tx *sqlx.Tx, pollingCode string) (TransferCodeStatus, erro ...@@ -36,8 +36,8 @@ func CheckPollingCode(tx *sqlx.Tx, pollingCode string) (TransferCodeStatus, erro
return p, err return p, err
} }
// PopTokenForPollingCode returns the decrypted token for a polling code and then deletes the entry // PopTokenForTransferCode returns the decrypted token for a polling code and then deletes the entry
func PopTokenForPollingCode(tx *sqlx.Tx, pollingCode string) (jwt string, err error) { func PopTokenForTransferCode(tx *sqlx.Tx, pollingCode string) (jwt string, err error) {
pt := createProxyToken(pollingCode) pt := createProxyToken(pollingCode)
var valid bool var valid bool
err = db.RunWithinTransaction(tx, func(tx *sqlx.Tx) error { err = db.RunWithinTransaction(tx, func(tx *sqlx.Tx) error {
......
...@@ -59,12 +59,6 @@ func CreatePollingCode(pollingCode string, responseType model.ResponseType) *Tra ...@@ -59,12 +59,6 @@ func CreatePollingCode(pollingCode string, responseType model.ResponseType) *Tra
// Store stores the TransferCode in the database // Store stores the TransferCode in the database
func (tc TransferCode) Store(tx *sqlx.Tx) error { func (tc TransferCode) Store(tx *sqlx.Tx) error {
log.Debug("Storing transfer code") log.Debug("Storing transfer code")
log.WithFields(log.Fields{
"id": tc.ID(),
"token": tc.Token(),
"djwt": tc.decryptedJWT,
"ejwt": tc.encryptedJWT,
}).Trace("TransferCode")
return db.RunWithinTransaction(tx, func(tx *sqlx.Tx) error { return db.RunWithinTransaction(tx, func(tx *sqlx.Tx) error {
if err := tc.proxyToken.Store(tx); err != nil { if err := tc.proxyToken.Store(tx); err != nil {
return err return err
......
package pkg package pkg
import (
"github.com/zachmann/mytoken/internal/model"
)
// CreateTransferCodeRequest is a request to create a new transfer code from an existing super token // CreateTransferCodeRequest is a request to create a new transfer code from an existing super token
type CreateTransferCodeRequest struct { type CreateTransferCodeRequest struct {
SuperToken string `json:"super_token"` // we use string and not token.Token because the token can also be in the Auth Header and there it is a string SuperToken string `json:"super_token"` // we use string and not token.Token because the token can also be in the Auth Header and there it is a string
} }
// ExchangeTransferCodeRequest is a request to exchange a transfer code for the super token
type ExchangeTransferCodeRequest struct {
GrantType model.GrantType `json:"grant_type"`
TransferCode string `json:"transfer_code"`
}
...@@ -25,7 +25,7 @@ func HandlePollingCode(ctx *fiber.Ctx) error { ...@@ -25,7 +25,7 @@ func HandlePollingCode(ctx *fiber.Ctx) error {
func handlePollingCode(req response.PollingCodeRequest, networkData model.ClientMetaData) *model.Response { func handlePollingCode(req response.PollingCodeRequest, networkData model.ClientMetaData) *model.Response {
pollingCode := req.PollingCode pollingCode := req.PollingCode
log.WithField("polling_code", pollingCode).Debug("Handle polling code") log.WithField("polling_code", pollingCode).Debug("Handle polling code")
pollingCodeStatus, err := transfercoderepo.CheckPollingCode(nil, pollingCode) pollingCodeStatus, err := transfercoderepo.CheckTransferCode(nil, pollingCode)
if err != nil { if err != nil {
return model.ErrorToInternalServerErrorResponse(err) return model.ErrorToInternalServerErrorResponse(err)
} }
...@@ -33,17 +33,17 @@ func handlePollingCode(req response.PollingCodeRequest, networkData model.Client ...@@ -33,17 +33,17 @@ func handlePollingCode(req response.PollingCodeRequest, networkData model.Client
log.WithField("polling_code", pollingCode).Debug("Polling code not known") log.WithField("polling_code", pollingCode).Debug("Polling code not known")
return &model.Response{ return &model.Response{
Status: fiber.StatusUnauthorized, Status: fiber.StatusUnauthorized,
Response: model.APIErrorBadPollingCode, Response: model.APIErrorBadTransferCode,
} }
} }
if pollingCodeStatus.Expired { if pollingCodeStatus.Expired {
log.WithField("polling_code", pollingCode).Debug("Polling code expired") log.WithField("polling_code", pollingCode).Debug("Polling code expired")
return &model.Response{ return &model.Response{
Status: fiber.StatusUnauthorized, Status: fiber.StatusUnauthorized,
Response: model.APIErrorPollingCodeExpired, Response: model.APIErrorTransferCodeExpired,
} }
} }
token, err := transfercoderepo.PopTokenForPollingCode(nil, pollingCode) token, err := transfercoderepo.PopTokenForTransferCode(nil, pollingCode)
if err != nil { if err != nil {
return model.ErrorToInternalServerErrorResponse(err) return model.ErrorToInternalServerErrorResponse(err)
} }
......
...@@ -38,7 +38,7 @@ func HandleSuperTokenEndpoint(ctx *fiber.Ctx) error { ...@@ -38,7 +38,7 @@ func HandleSuperTokenEndpoint(ctx *fiber.Ctx) error {
} }
case model.GrantTypeTransferCode: case model.GrantTypeTransferCode:
if config.Get().Features.TransferCodes.Enabled { if config.Get().Features.TransferCodes.Enabled {
return model.ResponseNYI.Send(ctx) return supertoken.HandleSuperTokenFromTransferCode(ctx).Send(ctx)
} }
} }
res := model.Response{ res := model.Response{
......
...@@ -22,8 +22,8 @@ var ( ...@@ -22,8 +22,8 @@ var (
APIErrorStateMismatch = APIError{ErrorInvalidRequest, "State mismatched"} APIErrorStateMismatch = APIError{ErrorInvalidRequest, "State mismatched"}
APIErrorUnsupportedOIDCFlow = APIError{ErrorInvalidGrant, "Unsupported oidc_flow"} APIErrorUnsupportedOIDCFlow = APIError{ErrorInvalidGrant, "Unsupported oidc_flow"}
APIErrorUnsupportedGrantType = APIError{ErrorInvalidGrant, "Unsupported grant_type"} APIErrorUnsupportedGrantType = APIError{ErrorInvalidGrant, "Unsupported grant_type"}
APIErrorBadPollingCode = APIError{ErrorAccessDenied, "Bad polling_code"} APIErrorBadTransferCode = APIError{ErrorInvalidToken, "Bad polling or transfer code"}
APIErrorPollingCodeExpired = APIError{ErrorExpiredToken, "polling_code is expired"} APIErrorTransferCodeExpired = APIError{ErrorExpiredToken, "polling or transfer code is expired"}
APIErrorAuthorizationPending = ErrorWithoutDescription(ErrorAuthorizationPending) APIErrorAuthorizationPending = ErrorWithoutDescription(ErrorAuthorizationPending)
APIErrorNoRefreshToken = APIError{ErrorOIDC, "Did not receive a refresh token"} APIErrorNoRefreshToken = APIError{ErrorOIDC, "Did not receive a refresh token"}
APIErrorInsufficientCapabilities = APIError{ErrorInsufficientCapabilities, "The provided token does not have the required capability for this operation"} APIErrorInsufficientCapabilities = APIError{ErrorInsufficientCapabilities, "The provided token does not have the required capability for this operation"}
......
...@@ -98,7 +98,8 @@ func NewSuperToken(oidcSub, oidcIss string, r restrictions.Restrictions, c, sc c ...@@ -98,7 +98,8 @@ func NewSuperToken(oidcSub, oidcIss string, r restrictions.Restrictions, c, sc c
return st return st
} }
func (st *SuperToken) expiresIn() uint64 { // ExpiresIn returns the amount of seconds in which this token expires
func (st *SuperToken) ExpiresIn() uint64 {
now := time.Now().Unix() now := time.Now().Unix()
expAt := st.ExpiresAt expAt := st.ExpiresAt
if expAt > 0 && expAt > now { if expAt > 0 && expAt > now {
...@@ -160,7 +161,7 @@ func (st *SuperToken) toShortSuperTokenResponse(jwt string) (response.SuperToken ...@@ -160,7 +161,7 @@ func (st *SuperToken) toShortSuperTokenResponse(jwt string) (response.SuperToken
func (st *SuperToken) toTokenResponse() response.SuperTokenResponse { func (st *SuperToken) toTokenResponse() response.SuperTokenResponse {
return response.SuperTokenResponse{ return response.SuperTokenResponse{
ExpiresIn: st.expiresIn(), ExpiresIn: st.ExpiresIn(),
Restrictions: st.Restrictions, Restrictions: st.Restrictions,
Capabilities: st.Capabilities, Capabilities: st.Capabilities,
SubtokenCapabilities: st.SubtokenCapabilities, SubtokenCapabilities: st.SubtokenCapabilities,
......
...@@ -14,6 +14,7 @@ import ( ...@@ -14,6 +14,7 @@ import (
"github.com/zachmann/mytoken/internal/db" "github.com/zachmann/mytoken/internal/db"
"github.com/zachmann/mytoken/internal/db/dbrepo/supertokenrepo" "github.com/zachmann/mytoken/internal/db/dbrepo/supertokenrepo"
dbhelper "github.com/zachmann/mytoken/internal/db/dbrepo/supertokenrepo/supertokenrepohelper" dbhelper "github.com/zachmann/mytoken/internal/db/dbrepo/supertokenrepo/supertokenrepohelper"
"github.com/zachmann/mytoken/internal/db/dbrepo/supertokenrepo/transfercoderepo"
response "github.com/zachmann/mytoken/internal/endpoints/token/super/pkg" response "github.com/zachmann/mytoken/internal/endpoints/token/super/pkg"
"github.com/zachmann/mytoken/internal/model" "github.com/zachmann/mytoken/internal/model"
"github.com/zachmann/mytoken/internal/oidc/revoke" "github.com/zachmann/mytoken/internal/oidc/revoke"
...@@ -24,9 +25,74 @@ import ( ...@@ -24,9 +25,74 @@ import (
supertoken "github.com/zachmann/mytoken/internal/supertoken/pkg" supertoken "github.com/zachmann/mytoken/internal/supertoken/pkg"
"github.com/zachmann/mytoken/internal/supertoken/restrictions" "github.com/zachmann/mytoken/internal/supertoken/restrictions"
"github.com/zachmann/mytoken/internal/supertoken/token" "github.com/zachmann/mytoken/internal/supertoken/token"
"github.com/zachmann/mytoken/internal/utils"
"github.com/zachmann/mytoken/internal/utils/ctxUtils" "github.com/zachmann/mytoken/internal/utils/ctxUtils"
) )
// HandleSuperTokenFromTransferCode handles requests to return the super token for a transfer code
func HandleSuperTokenFromTransferCode(ctx *fiber.Ctx) *model.Response {
log.Debug("Handle supertoken from transfercode")
req := response.ExchangeTransferCodeRequest{}
if err := json.Unmarshal(ctx.Body(), &req); err != nil {
return model.ErrorToBadRequestErrorResponse(err)
}
log.Trace("Parsed request")
var errorRes *model.Response = nil
var tokenStr string
if err := db.Transact(func(tx *sqlx.Tx) error {
status, err := transfercoderepo.CheckTransferCode(tx, req.TransferCode)
if err != nil {
return err
}
if !status.Found {
errorRes = &model.Response{
Status: fiber.StatusUnauthorized,
Response: model.APIErrorBadTransferCode,
}
return fmt.Errorf("error_res")
}
if status.Expired {
errorRes = &model.Response{
Status: fiber.StatusUnauthorized,
Response: model.APIErrorTransferCodeExpired,
}
return fmt.Errorf("error_res")
}
tokenStr, err = transfercoderepo.PopTokenForTransferCode(tx, req.TransferCode)
return err
}); err != nil {
if errorRes != nil {
return errorRes
}
return model.ErrorToInternalServerErrorResponse(err)
}
tokenType := model.ResponseTypeToken
if !utils.IsJWT(tokenStr) {
tokenType = model.ResponseTypeShortToken
}
jwt, err := token.GetLongSuperToken(tokenStr)
if err != nil {
return model.ErrorToInternalServerErrorResponse(err)
}
st, err := supertoken.ParseJWT(string(jwt))
if err != nil {
return model.ErrorToInternalServerErrorResponse(err)
}
return &model.Response{
Status: fiber.StatusOK,
Response: response.SuperTokenResponse{
SuperToken: tokenStr,
SuperTokenType: tokenType,
ExpiresIn: st.ExpiresIn(),
Restrictions: st.Restrictions,
Capabilities: st.Capabilities,
SubtokenCapabilities: st.SubtokenCapabilities,
},
}
}
// HandleSuperTokenFromSuperToken handles requests to create a super token from an existing super token // HandleSuperTokenFromSuperToken handles requests to create a super token from an existing super token
func HandleSuperTokenFromSuperToken(ctx *fiber.Ctx) *model.Response { func HandleSuperTokenFromSuperToken(ctx *fiber.Ctx) *model.Response {
log.Debug("Handle supertoken from supertoken") log.Debug("Handle supertoken from supertoken")
......
...@@ -12,14 +12,19 @@ import ( ...@@ -12,14 +12,19 @@ import (
type Token string type Token string
// UnmarshalJSON implements the json.Unmarshaler interface // UnmarshalJSON implements the json.Unmarshaler interface
func (t *Token) UnmarshalJSON(data []byte) error { func (t *Token) UnmarshalJSON(data []byte) (err error) {
var token string var token string
if err := json.Unmarshal(data, &token); err != nil { if err = json.Unmarshal(data, &token); err != nil {
return err return
} }
*t, err = GetLongSuperToken(token)
return
}
// GetLongSuperToken returns the long / jwt of a super token; the passed token can be a jwt or a short token
func GetLongSuperToken(token string) (Token, error) {
if utils.IsJWT(token) { if utils.IsJWT(token) {
*t = Token(token) return Token(token), nil
return nil
} }
shortToken := transfercoderepo.ParseShortToken(token) shortToken := transfercoderepo.ParseShortToken(token)
token, valid, dbErr := shortToken.JWT(nil) token, valid, dbErr := shortToken.JWT(nil)
...@@ -27,6 +32,5 @@ func (t *Token) UnmarshalJSON(data []byte) error { ...@@ -27,6 +32,5 @@ func (t *Token) UnmarshalJSON(data []byte) error {
if !valid { if !valid {
validErr = fmt.Errorf("token not valid") validErr = fmt.Errorf("token not valid")
} }
*t = Token(token) return Token(token), utils.ORErrors(dbErr, validErr)
return utils.ORErrors(dbErr, validErr)
} }
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment