-
Gabriel Zachmann authored
* refactor shared utils * fix duplicated imports
Gabriel Zachmann authored* refactor shared utils * fix duplicated imports
auther.go 6.88 KiB
package auth
import (
"github.com/gofiber/fiber/v2"
"github.com/jmoiron/sqlx"
"github.com/oidc-mytoken/api/v0"
log "github.com/sirupsen/logrus"
"github.com/oidc-mytoken/server/internal/config"
dbhelper "github.com/oidc-mytoken/server/internal/db/dbrepo/mytokenrepo/mytokenrepohelper"
"github.com/oidc-mytoken/server/internal/model"
mytoken "github.com/oidc-mytoken/server/internal/mytoken/pkg"
"github.com/oidc-mytoken/server/internal/mytoken/restrictions"
"github.com/oidc-mytoken/server/internal/mytoken/universalmytoken"
"github.com/oidc-mytoken/server/internal/utils/ctxutils"
"github.com/oidc-mytoken/server/internal/utils/errorfmt"
)
// RequireGrantType compares the grant type found in a request (got) against the one an endpoint expects (want).
// It returns nil when both match; otherwise it returns a 400 Bad Request model.Response carrying
// api.ErrorUnsupportedGrantType.
func RequireGrantType(rlog log.Ext1FieldLogger, want, got model.GrantType) *model.Response {
	if got == want {
		rlog.Trace("Checked grant type")
		return nil
	}
	return &model.Response{
		Status:   fiber.StatusBadRequest,
		Response: api.ErrorUnsupportedGrantType,
	}
}
// RequireMytoken obtains a mytoken string and parses it.
//
// If reqToken carries no JWT, the token is taken from the request via ctxutils.GetMytoken and copied into
// *reqToken. When no usable token is obtained this way, a 401 response is returned: "token not valid" if
// GetMytoken reported that something was found, "no mytoken found in request" otherwise. A token that fails
// mytoken.ParseJWT also yields a 401 response containing the formatted parse error. On success the parsed
// mytoken.Mytoken and a nil response are returned.
func RequireMytoken(rlog log.Ext1FieldLogger, reqToken *universalmytoken.UniversalMytoken, ctx *fiber.Ctx) (
	*mytoken.Mytoken, *model.Response,
) {
	if reqToken.JWT == "" {
		fromRequest, present := ctxutils.GetMytoken(ctx)
		if fromRequest == nil {
			desc := "token not valid"
			if !present {
				desc = "no mytoken found in request"
			}
			return nil, &model.Response{
				Status:   fiber.StatusUnauthorized,
				Response: model.InvalidTokenError(desc),
			}
		}
		// Write back so the caller's request struct holds the token that was actually used.
		*reqToken = *fromRequest
	}

	parsed, parseErr := mytoken.ParseJWT(reqToken.JWT)
	if parseErr != nil {
		return nil, &model.Response{
			Status:   fiber.StatusUnauthorized,
			Response: model.InvalidTokenError(errorfmt.Error(parseErr)),
		}
	}
	rlog.Trace("Parsed mytoken")
	return parsed, nil
}
// RequireMytokenNotRevoked asks the database (via dbhelper.CheckTokenRevoked, using the token's ID, SeqNo and
// Rotation) whether mt has been revoked.
// A database error is logged and converted into an internal-server-error model.Response; a revoked token yields
// a 401 response with an invalid-token error. nil is returned when the token is still valid.
func RequireMytokenNotRevoked(rlog log.Ext1FieldLogger, tx *sqlx.Tx, mt *mytoken.Mytoken) *model.Response {
	revoked, err := dbhelper.CheckTokenRevoked(rlog, tx, mt.ID, mt.SeqNo, mt.Rotation)
	switch {
	case err != nil:
		rlog.Errorf("%s", errorfmt.Full(err))
		return model.ErrorToInternalServerErrorResponse(err)
	case revoked:
		return &model.Response{
			Status:   fiber.StatusUnauthorized,
			Response: model.InvalidTokenError(""),
		}
	}
	rlog.Trace("Checked mytoken not revoked")
	return nil
}
// RequireValidMytoken combines RequireMytoken and RequireMytokenNotRevoked.
//
// It locates and parses the mytoken (from reqToken, or from the request via ctx when reqToken carries no JWT)
// and then asserts that the token has not been revoked. If either step fails, it returns a nil
// *mytoken.Mytoken together with the error model.Response, consistent with the other Require* helpers in this
// file. On success the parsed token and a nil response are returned.
func RequireValidMytoken(
	rlog log.Ext1FieldLogger, tx *sqlx.Tx, reqToken *universalmytoken.UniversalMytoken, ctx *fiber.Ctx,
) (
	*mytoken.Mytoken, *model.Response,
) {
	mt, errRes := RequireMytoken(rlog, reqToken, ctx)
	if errRes != nil {
		return nil, errRes
	}
	// Do not hand out the parsed token alongside an error response; callers must not use a revoked token.
	if errRes = RequireMytokenNotRevoked(rlog, tx, mt); errRes != nil {
		return nil, errRes
	}
	return mt, nil
}
// RequireMatchingIssuer checks that the OIDC issuer of a mytoken (mtOIDCIssuer) matches the issuer given in a
// request, and that this issuer is configured for this mytoken instance.
//
// If the request issuer is empty, it is set to mtOIDCIssuer (written back through requestIssuer). A nil
// requestIssuer is treated the same as an empty one instead of panicking; in that case nothing is written back.
// A mismatching issuer yields a 400 response ("token not for specified issuer"). An issuer not present in
// config.Get().ProviderByIssuer yields a 400 response with api.ErrorUnknownIssuer. On success the provider
// configuration for the issuer is returned.
func RequireMatchingIssuer(rlog log.Ext1FieldLogger, mtOIDCIssuer string, requestIssuer *string) (
	*config.ProviderConf, *model.Response,
) {
	if requestIssuer == nil {
		// Guard against a nil pointer dereference below; behave as if no issuer was given.
		requestIssuer = new(string)
	}
	if *requestIssuer == "" {
		*requestIssuer = mtOIDCIssuer
		rlog.Trace("Checked issuer (was not given)")
	}
	if *requestIssuer != mtOIDCIssuer {
		return nil, &model.Response{
			Status:   fiber.StatusBadRequest,
			Response: model.BadRequestError("token not for specified issuer"),
		}
	}
	provider, ok := config.Get().ProviderByIssuer[*requestIssuer]
	if !ok {
		return nil, &model.Response{
			Status:   fiber.StatusBadRequest,
			Response: api.ErrorUnknownIssuer,
		}
	}
	rlog.Trace("Checked issuer")
	return provider, nil
}
// RequireCapability reports whether mt's capabilities include the requested api.Capability.
// It returns nil if they do, and a 403 Forbidden model.Response carrying api.ErrorInsufficientCapabilities
// otherwise.
func RequireCapability(rlog log.Ext1FieldLogger, capability api.Capability, mt *mytoken.Mytoken) *model.Response {
	if mt.Capabilities.Has(capability) {
		rlog.Trace("Checked capability")
		return nil
	}
	return &model.Response{
		Status:   fiber.StatusForbidden,
		Response: api.ErrorInsufficientCapabilities,
	}
}
// requireUseableRestriction picks the first of mt's restrictions that allows the current usage.
//
// A token without any restrictions yields (nil, nil). Otherwise, the restrictions valid for the given ip and
// token ID are selected: GetValidForAT when at is true, GetValidForOther when it is false. These are then
// narrowed by scopes and audiences. If none remain, a 403 response with api.ErrorUsageRestricted is returned.
func requireUseableRestriction(
	rlog log.Ext1FieldLogger, tx *sqlx.Tx, mt *mytoken.Mytoken, ip string, scopes, auds []string, at bool,
) (*restrictions.Restriction, *model.Response) {
	if len(mt.Restrictions) == 0 {
		return nil, nil
	}

	validFor := mt.Restrictions.GetValidForAT
	if !at {
		validFor = mt.Restrictions.GetValidForOther
	}

	valid := validFor(rlog, tx, ip, mt.ID)
	// WithScopes and WithAudiences don't tighten the restrictions if nil is passed
	scoped := valid.WithScopes(rlog, scopes)
	usable := scoped.WithAudiences(rlog, auds)

	if len(usable) > 0 {
		rlog.Trace("Checked mytoken restrictions")
		return usable[0], nil
	}
	return nil, &model.Response{
		Status:   fiber.StatusForbidden,
		Response: api.ErrorUsageRestricted,
	}
}
// RequireUsableRestriction checks that mt's restrictions permit the usage described by ip, scopes and auds.
// The AT restriction set is consulted when capability is api.CapabilityAT; the non-AT set is used for every
// other capability.
func RequireUsableRestriction(
	rlog log.Ext1FieldLogger, tx *sqlx.Tx, mt *mytoken.Mytoken, ip string, scopes, auds []string,
	capability api.Capability,
) (*restrictions.Restriction, *model.Response) {
	forAT := capability == api.CapabilityAT
	return requireUseableRestriction(rlog, tx, mt, ip, scopes, auds, forAT)
}
// RequireUsableRestrictionAT checks that mt's restrictions permit using it to obtain an access token.
// It is shorthand for RequireUsableRestriction with api.CapabilityAT.
func RequireUsableRestrictionAT(
	rlog log.Ext1FieldLogger, tx *sqlx.Tx, mt *mytoken.Mytoken, ip string, scopes, auds []string,
) (*restrictions.Restriction, *model.Response) {
	return RequireUsableRestriction(rlog, tx, mt, ip, scopes, auds, api.CapabilityAT)
}
// RequireUsableRestrictionOther checks that mt's restrictions permit a usage other than obtaining an access
// token.
func RequireUsableRestrictionOther(
	rlog log.Ext1FieldLogger, tx *sqlx.Tx, mt *mytoken.Mytoken, ip string, scopes, auds []string,
) (*restrictions.Restriction, *model.Response) {
	// Non-AT usage is evaluated against the restrictions' GetValidForOther selection.
	const forAT = false
	return requireUseableRestriction(rlog, tx, mt, ip, scopes, auds, forAT)
}
// CheckCapabilityAndRestriction first requires that mt holds capability. It then requires that one of mt's
// restrictions permits the usage (see RequireUsableRestriction). The capability error is returned as-is if the
// first check fails; otherwise the result of the restriction check is returned.
func CheckCapabilityAndRestriction(
	rlog log.Ext1FieldLogger, tx *sqlx.Tx, mt *mytoken.Mytoken, ip string, scopes, auds []string,
	capability api.Capability,
) (*restrictions.Restriction, *model.Response) {
	capErr := RequireCapability(rlog, capability, mt)
	if capErr != nil {
		return nil, capErr
	}
	return RequireUsableRestriction(rlog, tx, mt, ip, scopes, auds, capability)
}