Skip to content
Snippets Groups Projects
Commit e2ccdfed authored by Gabriel Zachmann's avatar Gabriel Zachmann
Browse files

WIP auth code exchange

parent a3e58c1d
No related branches found
No related tags found
No related merge requests found
Showing
with 258 additions and 25 deletions
issuer: "https://mytoken.example.com"
server:
hostname: "mytoken.example.com"
signing_key_file: "/mytoken.key"
database:
host: "localhost"
user: "mytoken"
......@@ -24,5 +23,7 @@ enabled_super_token_endpoint_grant_types:
- "access_token"
- "private_key_jwt"
polling_code_expires_after: 300
token_signing_alg: "RS512"
signing:
alg: "RS512"
key_file: "/mytoken.key"
service_documentation: "https://github.com/zachmann/mytoken"
......@@ -23,10 +23,9 @@ type Config struct {
Providers []*ProviderConf `yaml:"providers"`
ProviderByIssuer map[string]*ProviderConf `yaml:"-"`
IssuerURL string `yaml:"issuer"`
SigningKeyFile string `yaml:"signing_key_file"`
EnabledOIDCFlows []model.OIDCFlow `yaml:"enabled_oidc_flows"`
EnabledSuperTokenEndpointGrantTypes []model.GrantType `yaml:"enabled_super_token_endpoint_grant_types"`
TokenSigningAlg string `yaml:"token_signing_alg"`
Signing signingConf `yaml:"signing"`
ServiceDocumentation string `yaml:"service_documentation"`
PollingCodeExpiresAfter int64 `yaml:"polling_code_expires_after"`
}
......@@ -42,6 +41,11 @@ type serverConf struct {
Hostname string `yaml:"hostname"`
}
type signingConf struct {
Alg string `yaml:"alg"`
KeyFile string `yaml:"key_file"`
}
// ProviderConf holds information about a provider
type ProviderConf struct {
Issuer string `yaml:"issuer"`
......@@ -94,10 +98,10 @@ func validate() error {
if conf.IssuerURL == "" {
return fmt.Errorf("invalid config: issuerurl not set")
}
if conf.SigningKeyFile == "" {
if conf.Signing.KeyFile == "" {
return fmt.Errorf("invalid config: signingkeyfile not set")
}
if conf.TokenSigningAlg == "" {
if conf.Signing.Alg == "" {
return fmt.Errorf("invalid config: tokensigningalg not set")
}
model.OIDCFlowAuthorizationCode.AddToSliceIfNotFound(conf.EnabledOIDCFlows)
......
package dbModels
import (
"database/sql"
"github.com/jmoiron/sqlx"
uuid "github.com/satori/go.uuid"
"github.com/zachmann/mytoken/internal/db"
"github.com/zachmann/mytoken/internal/supertoken/capabilities"
)
type AccessToken struct {
Token string
IP string `db:"ip_created"`
Comment string
STID uuid.UUID `db:"ST_id"`
Scopes []string
Capabilities capabilities.Capabilities
Audiences []string
}
type accessToken struct {
Token string
IP string
Comment sql.NullString
STID uuid.UUID
ScopeAttr []Attribute
CapabilityAttr []Attribute
AudienceAttr []Attribute
}
func (t *AccessToken) Store() error {
return db.Transact(func(tx *sqlx.Tx) error {
res, err := tx.NamedExec(`INSERT INTO AccessTokens (token, ip_created, comment, ST_id) VALUES (:token, :ip_created, :comment, :ST_id)`, t)
if err != nil {
return err
}
atID, err := res.LastInsertId()
if err != nil {
return err
}
if _, err := tx.NamedExec(`INSERT INTO AT_Attributes (AT_id, attribute_id, attribute) VALUES (:token, :ip_created, :comment, :ST_id)`, t); err != nil {
return err
}
return nil
})
}
package dbModels
type Attribute struct {
ATID uint64 `db:"AT_id"`
AttrID uint64 `db:"attribute_id"`
Attr string `db:"attribute"`
}
......@@ -17,7 +17,7 @@ type AuthFlowInfo struct {
Restrictions restrictions.Restrictions
Capabilities capabilities.Capabilities
Name string
PollingCode string
PollingCode string `db:"polling_code"`
}
type authFlowInfo struct {
......@@ -59,3 +59,11 @@ func (e *AuthFlowInfo) Store() error {
return err
})
}
func GetAuthCodeInfoByState(state string) (info AuthFlowInfo, err error) {
if e := db.DB().Get(&info, `SELECT * FROM AuthInfoV WHERE state=?`, state); e != nil {
err = e
return
}
return
}
......@@ -4,6 +4,8 @@ import (
"database/sql"
"time"
"github.com/jmoiron/sqlx"
"github.com/zachmann/mytoken/internal/supertoken/capabilities"
"github.com/zachmann/mytoken/internal/supertoken/restrictions"
......@@ -30,10 +32,7 @@ type SuperTokenEntry struct {
IP string `db:"ip_created"`
}
func NewSuperTokenEntry(name, oidcSub, oidcIss string, r restrictions.Restrictions, c capabilities.Capabilities) *SuperTokenEntry {
//TODO
ip := "192.168.0.31"
func NewSuperTokenEntry(name, oidcSub, oidcIss string, r restrictions.Restrictions, c capabilities.Capabilities, ip string) *SuperTokenEntry {
st := supertoken.NewSuperToken(oidcSub, oidcIss, r, c)
return &SuperTokenEntry{
ID: st.ID,
......@@ -84,6 +83,24 @@ type superTokenEntryStore struct {
}
func (e *superTokenEntryStore) Store() error {
_, err := db.DB().NamedExec(`INSERT INTO SuperTokens (id, parent_id, root_id, revoked, token, refresh_token, name, ip_created, user_id) VALUES(:id, :parent_id, :root_id, :revoked, :token, :refresh_token, :name, :ip_created, (SELECT id FROM Users WHERE iss=:iss AND sub=:sub))`, e)
return err
stmt, err := db.DB().PrepareNamed(`INSERT INTO SuperTokens (id, parent_id, root_id, revoked, token, refresh_token, name, ip_created, user_id) VALUES(:id, :parent_id, :root_id, :revoked, :token, :refresh_token, :name, :ip_created, (SELECT id FROM Users WHERE iss=:iss AND sub=:sub))`)
if err != nil {
return err
}
return db.Transact(func(tx *sqlx.Tx) error {
txStmt := tx.NamedStmt(stmt)
_, err := txStmt.Exec(e)
if err != nil {
if err.Error() == "correct" {
_, err = tx.NamedExec(`INSERT INTO Users (sub, iss) VALUES(:sub, :iss)`, e)
if err != nil {
return err
}
_, err = txStmt.Exec(e)
return err
}
return err
}
return nil
})
}
......@@ -33,7 +33,7 @@ func Init() {
RevocationEndpoint: utils.CombineURLPath(config.Get().IssuerURL, apiPath.CURRENT, "/revocation"),
JWKSURI: utils.CombineURLPath(config.Get().IssuerURL, "/jwks"),
ProvidersSupported: getProvidersFromConfig(),
TokenSigningAlgValue: config.Get().TokenSigningAlg,
TokenSigningAlgValue: config.Get().Signing.Alg,
AccessTokenEndpointGrantTypesSupported: []model.GrantType{model.GrantTypeSuperToken},
SuperTokenEndpointGrantTypesSupported: config.Get().EnabledSuperTokenEndpointGrantTypes,
SuperTokenEndpointOIDCFlowsSupported: config.Get().EnabledOIDCFlows,
......
package redirect
import (
"fmt"
"github.com/zachmann/mytoken/internal/oidc/authcode"
"github.com/gofiber/fiber/v2"
)
func HandleOIDCRedirect(ctx *fiber.Ctx) error {
error := ctx.Params("error")
state := ctx.Params("state")
if error != "" {
errorDescription := ctx.Params("error_description")
if errorDescription != "" {
error = fmt.Sprintf("%s: %s", error, errorDescription)
}
if state != "" {
//TODO delete AuthInfo (and pollingCode)
}
ctx.SendStatus(fiber.StatusBadRequest)
return ctx.SendString(fmt.Sprintf("error: %s", error))
}
code := ctx.Params("code")
res := authcode.CodeExchange(state, code, ctx.IP())
return res.Send(ctx)
}
......@@ -7,8 +7,8 @@ import (
type authCodeFlowResponse struct {
AuthorizationURL string `json:"authorization_url"`
PollingCode string `json:"polling_code"`
PollingCodeExpires int64 `json:"polling_code_expires"`
PollingCode string `json:"polling_code,omitempty"`
PollingCodeExpires int64 `json:"polling_code_expires,omitempty"`
}
type AuthCodeFlowResponse struct {
......@@ -19,9 +19,11 @@ type AuthCodeFlowResponse struct {
func (r AuthCodeFlowResponse) MarshalJSON() ([]byte, error) {
rr := authCodeFlowResponse{
AuthorizationURL: r.AuthorizationURL,
PollingCode: r.PollingCode,
PollingCodeExpires: r.PollingCodeExpires.Unix(),
AuthorizationURL: r.AuthorizationURL,
PollingCode: r.PollingCode,
}
if rr.PollingCode != "" {
rr.PollingCodeExpires = r.PollingCodeExpires.Unix()
}
return json.Marshal(rr)
}
......
......@@ -13,8 +13,8 @@ import (
// GenerateRSAKeyPair generates an RSA key pair
func GenerateRSAKeyPair() (*rsa.PrivateKey, *rsa.PublicKey) {
privkey, _ := rsa.GenerateKey(rand.Reader, 4096)
return privkey, &privkey.PublicKey
sk, _ := rsa.GenerateKey(rand.Reader, 1024)
return sk, &sk.PublicKey
}
// ExportRSAPrivateKeyAsPemStr exports the private key
......@@ -59,7 +59,7 @@ func GetPublicKey() *rsa.PublicKey {
// Init does init
func Init() {
keyFileContent, err := ioutil.ReadFile(config.Get().SigningKeyFile)
keyFileContent, err := ioutil.ReadFile(config.Get().Signing.KeyFile)
if err != nil {
panic(err)
}
......
package model
type APIError struct {
Error string `json:"error"`
ErrorDescription string `json:"error_description,omitempty"`
}
// Predefined errors
var (
APIErrorUnknownIssuer = APIError{ErrorInvalidRequest, "The provided issuer is not supported"}
APIErrorStateMismatch = APIError{ErrorInvalidRequest, "State mismatched"}
)
// Predefined OAuth2/OIDC errors
const (
ErrorInvalidRequest = "invalid_request"
ErrorInvalidClient = "invalid_client"
ErrorInvalidGrant = "invalid_grant"
ErrorUnauthorizedClient = "unauthorized_client"
ErrorUnsupportedGrantType = "unsupported_grant_type"
ErrorInvalidScope = "invalid_scope"
ErrorInvalidToken = "invalid_token"
ErrorInsufficientScope = "insufficient_scope"
)
// Additional Mytoken errors
const (
ErrorInternal = "internal_server_error"
)
func InternalServerError(errorDescription string) APIError {
return APIError{
Error: ErrorInternal,
ErrorDescription: errorDescription,
}
}
package model
import "github.com/gofiber/fiber/v2"
type Response struct {
Status int
Response interface{}
}
func (r *Response) Send(ctx *fiber.Ctx) error {
return ctx.Status(r.Status).JSON(r.Response)
}
func ErrorToInternalServerErrorResponse(err error) Response {
return Response{
Status: fiber.StatusInternalServerError,
Response: InternalServerError(err.Error()),
}
}
package authcode
import (
"context"
"fmt"
"log"
"time"
"github.com/coreos/go-oidc/v3/oidc"
"github.com/gofiber/fiber/v2"
"github.com/zachmann/mytoken/internal/model"
"github.com/zachmann/mytoken/internal/utils/issuerUtils"
"github.com/zachmann/mytoken/internal/config"
......@@ -91,3 +98,56 @@ func InitAuthCodeFlow(provider *config.ProviderConf, req *response.AuthCodeFlowR
}
return
}
func CodeExchange(state, code, ip string) model.Response {
log.Print("Handle code exchange")
authInfo, err := dbModels.GetAuthCodeInfoByState(state)
if err != nil {
return model.Response{
Status: fiber.StatusBadRequest,
Response: model.APIErrorStateMismatch,
}
}
provider, ok := config.Get().ProviderByIssuer[authInfo.Issuer]
if !ok {
return model.Response{
Status: fiber.StatusBadRequest,
Response: model.APIErrorUnknownIssuer,
}
}
oauth2Config := oauth2.Config{
ClientID: provider.ClientID,
ClientSecret: provider.ClientSecret,
Endpoint: provider.Provider.Endpoint(),
}
token, err := oauth2Config.Exchange(context.Background(), code)
if err != nil {
return model.ErrorToInternalServerErrorResponse(err)
}
oidcSub, err := getSubjectFromUserinfo(provider.Provider, token)
if err != nil {
return model.ErrorToInternalServerErrorResponse(err)
}
ste, err := createSuperTokenEntry(authInfo, token, oidcSub, ip)
if err != nil {
return model.ErrorToInternalServerErrorResponse(err)
}
}
func createSuperTokenEntry(authFlowInfo dbModels.AuthFlowInfo, token *oauth2.Token, oidcSub, ip string) (*dbModels.SuperTokenEntry, error) {
ste := dbModels.NewSuperTokenEntry(authFlowInfo.Name, oidcSub, authFlowInfo.Issuer, authFlowInfo.Restrictions, authFlowInfo.Capabilities, ip)
ste.RefreshToken = token.RefreshToken
err := ste.Store("Used grant_type oidc_flow authorization_code")
if err != nil {
return nil, err
}
return ste, nil
}
func getSubjectFromUserinfo(provider *oidc.Provider, token *oauth2.Token) (string, error) {
userInfo, err := provider.UserInfo(context.Background(), oauth2.StaticTokenSource(token))
if err != nil {
return "", fmt.Errorf("failed to get userinfo: %s", err)
}
return userInfo.Subject, nil
}
......@@ -84,7 +84,7 @@ func (st *SuperToken) Valid() error {
// ToJWT returns the SuperToken as JWT
func (st *SuperToken) ToJWT() (string, error) {
return jwt.NewWithClaims(jwt.GetSigningMethod(config.Get().TokenSigningAlg), st).SignedString(jws.GetPrivateKey())
return jwt.NewWithClaims(jwt.GetSigningMethod(config.Get().Signing.Alg), st).SignedString(jws.GetPrivateKey())
}
// Value implements the driver.Valuer interface.
......
......@@ -24,10 +24,10 @@ func (t *SuperTokenEntryTree) print(level int) {
}
}
func NewSuperTokenEntryFromSuperToken(name string, parent dbModels.SuperTokenEntry, r restrictions.Restrictions, c capabilities.Capabilities) (*dbModels.SuperTokenEntry, error) {
func NewSuperTokenEntryFromSuperToken(name string, parent dbModels.SuperTokenEntry, r restrictions.Restrictions, c capabilities.Capabilities, ip string) (*dbModels.SuperTokenEntry, error) {
newRestrictions := restrictions.Tighten(parent.Token.Restrictions, r)
newCapabilities := capabilities.Tighten(parent.Token.Capabilities, c)
ste := dbModels.NewSuperTokenEntry(name, parent.Token.OIDCSubject, parent.Token.OIDCIssuer, newRestrictions, newCapabilities)
ste := dbModels.NewSuperTokenEntry(name, parent.Token.OIDCSubject, parent.Token.OIDCIssuer, newRestrictions, newCapabilities, ip)
ste.ParentID = parent.ID.String()
rootID := parent.ID.String()
if !parent.Root() {
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment