From 8cf31513a3c20ffcf5fef161cf81d43b31d699d0 Mon Sep 17 00:00:00 2001 From: zachmann <gabriel.zachmann@kit.edu> Date: Wed, 24 Feb 2021 11:09:52 +0100 Subject: [PATCH] fix some code issues --- .../mytoken-dbGarbageCollector/main.go | 20 +++++++++++-------- internal/config/config.go | 16 +++++++-------- internal/endpoints/consent/pkg/capability.go | 16 +++++++-------- .../token/super/pkg/authCodeFlowRequest.go | 5 +---- pkg/model/oidcFlow.go | 1 - pkg/mytokenlib/supertoken.go | 3 +-- shared/utils/jwtutils/jwtutils.go | 8 ++++---- shared/utils/utils.go | 7 +++---- 8 files changed, 37 insertions(+), 39 deletions(-) diff --git a/cmd/mytoken-server/mytoken-dbGarbageCollector/main.go b/cmd/mytoken-server/mytoken-dbGarbageCollector/main.go index 6ff8b711..eed543da 100644 --- a/cmd/mytoken-server/mytoken-dbGarbageCollector/main.go +++ b/cmd/mytoken-server/mytoken-dbGarbageCollector/main.go @@ -1,6 +1,7 @@ package main import ( + "github.com/jmoiron/sqlx" log "github.com/sirupsen/logrus" "github.com/oidc-mytoken/server/internal/config" @@ -11,21 +12,24 @@ import ( func main() { config.Load() loggerUtils.Init() - if err := db.Connect(); err != nil { - log.WithError(err).Fatal() - } + db.Connect() deleteExpiredTransferCodes() deleteExpiredAuthInfo() } -func deleteExpiredTransferCodes() { - if _, err := db.DB().Exec(`DELETE FROM ProxyTokens WHERE id = ANY(SELECT id FROM TransferCodesAttributes WHERE expires_at < CURRENT_TIMESTAMP())`); err != nil { +func execSimpleQuery(sql string) { + if err := db.RunWithinTransaction(nil, func(tx *sqlx.Tx) error { + _, err := tx.Exec(sql) + return err + }); err != nil { log.WithError(err).Error() } } +func deleteExpiredTransferCodes() { + execSimpleQuery(`DELETE FROM ProxyTokens WHERE id = ANY(SELECT id FROM TransferCodesAttributes WHERE expires_at < CURRENT_TIMESTAMP())`) +} + func deleteExpiredAuthInfo() { - if _, err := db.DB().Exec(`DELETE FROM AuthInfo WHERE expires_at < CURRENT_TIMESTAMP()`); err != nil { - log.WithError(err).Error() - } + execSimpleQuery(`DELETE FROM AuthInfo WHERE expires_at < CURRENT_TIMESTAMP()`) } diff --git a/internal/config/config.go b/internal/config/config.go index f17bc3f2..a9538b4d 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -14,7 +14,7 @@ import ( "github.com/oidc-mytoken/server/shared/utils/issuerUtils" ) -var defaultConfig = config{ +var defaultConfig = Config{ Server: serverConf{ Port: 8000, TLS: tlsConf{ @@ -70,8 +70,8 @@ var defaultConfig = config{ }, } -// config holds the server configuration -type config struct { +// Config holds the server configuration +type Config struct { IssuerURL string `yaml:"issuer"` Server serverConf `yaml:"server"` GeoIPDBFile string `yaml:"geo_ip_db_file"` @@ -166,10 +166,10 @@ type ProviderConf struct { AudienceRequestParameter string `yaml:"audience_request_parameter"` } -var conf *config +var conf *Config -// Get returns the config -func Get() *config { +// Get returns the Config +func Get() *Config { return conf } @@ -240,7 +240,7 @@ var possibleConfigLocations = []string{ "/etc/mytoken", } -// Load reads the config file and populates the config struct; then validates the config +// Load reads the config file and populates the Config struct; then validates the Config func Load() { load() if err := validate(); err != nil { @@ -258,7 +258,7 @@ func load() { } } -// LoadForSetup reads the config file and populates the config struct; it does not validate the config, since this is not required for setup +// LoadForSetup reads the config file and populates the Config struct; it does not validate the Config, since this is not required for setup func LoadForSetup() { load() } diff --git a/internal/endpoints/consent/pkg/capability.go b/internal/endpoints/consent/pkg/capability.go index aa2029c4..d5d36bab 100644 --- a/internal/endpoints/consent/pkg/capability.go +++ b/internal/endpoints/consent/pkg/capability.go @@ -5,14 +5,14 @@ import ( "github.com/oidc-mytoken/server/shared/utils" ) -type webCapability struct { +type WebCapability struct { capabilities.Capability intClass *int } -func WebCapabilities(cc capabilities.Capabilities) (wc []webCapability) { +func WebCapabilities(cc capabilities.Capabilities) (wc []WebCapability) { for _, c := range cc { - wc = append(wc, webCapability{c, nil}) + wc = append(wc, WebCapability{c, nil}) } return } @@ -28,7 +28,7 @@ var normalCapabilities = []string{"AT", "create_super_token", "tokeninfo_introsp var warningCapabilities = []string{"list_super_tokens"} var dangerCapabilities = []string{"settings"} -func (c webCapability) getIntClass() int { +func (c WebCapability) getIntClass() int { if c.intClass != nil { return *c.intClass } @@ -47,15 +47,15 @@ func (c webCapability) getIntClass() int { return -1 } -func (c webCapability) getDangerLevel() int { +func (c WebCapability) getDangerLevel() int { return c.getIntClass() } -func (c webCapability) ColorClass() string { +func (c WebCapability) ColorClass() string { return textColorByDanger(c.getDangerLevel()) } -func (c webCapability) CapabilityLevel() string { +func (c WebCapability) CapabilityLevel() string { intClass := c.getIntClass() switch intClass { case 0: @@ -68,6 +68,6 @@ func (c webCapability) CapabilityLevel() string { return "" } -func (c webCapability) IsCreateST() bool { +func (c WebCapability) IsCreateST() bool { return c.Name == capabilities.CapabilityCreateST.Name } diff --git a/internal/endpoints/token/super/pkg/authCodeFlowRequest.go b/internal/endpoints/token/super/pkg/authCodeFlowRequest.go index fb22f0b1..bbcaea57 100644 --- a/internal/endpoints/token/super/pkg/authCodeFlowRequest.go +++ b/internal/endpoints/token/super/pkg/authCodeFlowRequest.go @@ -18,10 +18,7 @@ type AuthCodeFlowRequest struct { // Native checks if the request is native func (r *AuthCodeFlowRequest) Native() bool { - if r.RedirectType == redirectTypeNative { - return true - } - return false + return r.RedirectType == redirectTypeNative } // UnmarshalJSON implements the json unmarshaler interface diff --git a/pkg/model/oidcFlow.go b/pkg/model/oidcFlow.go index 20c71620..cb3f923f 100644 --- a/pkg/model/oidcFlow.go +++ b/pkg/model/oidcFlow.go @@ -78,7 +78,6 @@ func (f OIDCFlow) AddToSliceIfNotFound(s *[]OIDCFlow) { return } *s = append(*s, f) - return } // OIDCFlowIsInSlice checks if a OIDCFlow is present in a slice of OIDCFlows diff --git a/pkg/mytokenlib/supertoken.go b/pkg/mytokenlib/supertoken.go index 3dd52f0d..f2b49b54 100644 --- a/pkg/mytokenlib/supertoken.go +++ b/pkg/mytokenlib/supertoken.go @@ -141,8 +141,7 @@ func (my *Mytoken) PollOnce(pollingCode string) (string, bool, error) { } var myErr *MytokenError if errors.As(err, &myErr) { - switch myErr.err { - case model.ErrorAuthorizationPending: + if myErr.err == model.ErrorAuthorizationPending { err = nil } } diff --git a/shared/utils/jwtutils/jwtutils.go b/shared/utils/jwtutils/jwtutils.go index 5f60c393..acf0aed1 100644 --- a/shared/utils/jwtutils/jwtutils.go +++ b/shared/utils/jwtutils/jwtutils.go @@ -50,14 +50,14 @@ func GetAudiencesFromJWT(token string) ([]string, bool) { return nil, false } auds := res[0].Value - switch auds.(type) { + switch v := auds.(type) { case string: - return []string{auds.(string)}, true + return []string{v}, true case []string: - return auds.([]string), true + return v, true case []interface{}: strs := []string{} - for _, s := range auds.([]interface{}) { + for _, s := range v { str, ok := s.(string) if !ok { return nil, false diff --git a/shared/utils/utils.go b/shared/utils/utils.go index 7b694d53..c0396657 100644 --- a/shared/utils/utils.go +++ b/shared/utils/utils.go @@ -111,11 +111,10 @@ func IPIsIn(ip string, ips []string) bool { if ipNetB != nil && ipNetB.Contains(ipA) { return true } - } else { - if ip == ipp { - return true - } + } else if ip == ipp { + return true } + } return false } -- GitLab