From 8cf31513a3c20ffcf5fef161cf81d43b31d699d0 Mon Sep 17 00:00:00 2001
From: zachmann <gabriel.zachmann@kit.edu>
Date: Wed, 24 Feb 2021 11:09:52 +0100
Subject: [PATCH] fix some code issues

---
 .../mytoken-dbGarbageCollector/main.go        | 20 +++++++++++--------
 internal/config/config.go                     | 16 +++++++--------
 internal/endpoints/consent/pkg/capability.go  | 16 +++++++--------
 .../token/super/pkg/authCodeFlowRequest.go    |  5 +----
 pkg/model/oidcFlow.go                         |  1 -
 pkg/mytokenlib/supertoken.go                  |  3 +--
 shared/utils/jwtutils/jwtutils.go             |  8 ++++----
 shared/utils/utils.go                         |  7 +++----
 8 files changed, 37 insertions(+), 39 deletions(-)

diff --git a/cmd/mytoken-server/mytoken-dbGarbageCollector/main.go b/cmd/mytoken-server/mytoken-dbGarbageCollector/main.go
index 6ff8b711..eed543da 100644
--- a/cmd/mytoken-server/mytoken-dbGarbageCollector/main.go
+++ b/cmd/mytoken-server/mytoken-dbGarbageCollector/main.go
@@ -1,6 +1,7 @@
 package main
 
 import (
+	"github.com/jmoiron/sqlx"
 	log "github.com/sirupsen/logrus"
 
 	"github.com/oidc-mytoken/server/internal/config"
@@ -11,21 +12,24 @@ import (
 func main() {
 	config.Load()
 	loggerUtils.Init()
-	if err := db.Connect(); err != nil {
-		log.WithError(err).Fatal()
-	}
+	db.Connect()
 	deleteExpiredTransferCodes()
 	deleteExpiredAuthInfo()
 }
 
-func deleteExpiredTransferCodes() {
-	if _, err := db.DB().Exec(`DELETE FROM ProxyTokens WHERE id = ANY(SELECT id FROM TransferCodesAttributes WHERE expires_at < CURRENT_TIMESTAMP())`); err != nil {
+func execSimpleQuery(sql string) {
+	if err := db.RunWithinTransaction(nil, func(tx *sqlx.Tx) error {
+		_, err := tx.Exec(sql)
+		return err
+	}); err != nil {
 		log.WithError(err).Error()
 	}
 }
 
+func deleteExpiredTransferCodes() {
+	execSimpleQuery(`DELETE FROM ProxyTokens WHERE id = ANY(SELECT id FROM TransferCodesAttributes WHERE expires_at < CURRENT_TIMESTAMP())`)
+}
+
 func deleteExpiredAuthInfo() {
-	if _, err := db.DB().Exec(`DELETE FROM AuthInfo WHERE expires_at < CURRENT_TIMESTAMP()`); err != nil {
-		log.WithError(err).Error()
-	}
+	execSimpleQuery(`DELETE FROM AuthInfo WHERE expires_at < CURRENT_TIMESTAMP()`)
 }
diff --git a/internal/config/config.go b/internal/config/config.go
index f17bc3f2..a9538b4d 100644
--- a/internal/config/config.go
+++ b/internal/config/config.go
@@ -14,7 +14,7 @@ import (
 	"github.com/oidc-mytoken/server/shared/utils/issuerUtils"
 )
 
-var defaultConfig = config{
+var defaultConfig = Config{
 	Server: serverConf{
 		Port: 8000,
 		TLS: tlsConf{
@@ -70,8 +70,8 @@ var defaultConfig = config{
 	},
 }
 
-// config holds the server configuration
-type config struct {
+// Config holds the server configuration
+type Config struct {
 	IssuerURL            string                   `yaml:"issuer"`
 	Server               serverConf               `yaml:"server"`
 	GeoIPDBFile          string                   `yaml:"geo_ip_db_file"`
@@ -166,10 +166,10 @@ type ProviderConf struct {
 	AudienceRequestParameter string             `yaml:"audience_request_parameter"`
 }
 
-var conf *config
+var conf *Config
 
-// Get returns the config
-func Get() *config {
+// Get returns the Config
+func Get() *Config {
 	return conf
 }
 
@@ -240,7 +240,7 @@ var possibleConfigLocations = []string{
 	"/etc/mytoken",
 }
 
-// Load reads the config file and populates the config struct; then validates the config
+// Load reads the config file and populates the Config struct; then validates the Config
 func Load() {
 	load()
 	if err := validate(); err != nil {
@@ -258,7 +258,7 @@ func load() {
 	}
 }
 
-// LoadForSetup reads the config file and populates the config struct; it does not validate the config, since this is not required for setup
+// LoadForSetup reads the config file and populates the Config struct; it does not validate the Config, since this is not required for setup
 func LoadForSetup() {
 	load()
 }
diff --git a/internal/endpoints/consent/pkg/capability.go b/internal/endpoints/consent/pkg/capability.go
index aa2029c4..d5d36bab 100644
--- a/internal/endpoints/consent/pkg/capability.go
+++ b/internal/endpoints/consent/pkg/capability.go
@@ -5,14 +5,14 @@ import (
 	"github.com/oidc-mytoken/server/shared/utils"
 )
 
-type webCapability struct {
+type WebCapability struct {
 	capabilities.Capability
 	intClass *int
 }
 
-func WebCapabilities(cc capabilities.Capabilities) (wc []webCapability) {
+func WebCapabilities(cc capabilities.Capabilities) (wc []WebCapability) {
 	for _, c := range cc {
-		wc = append(wc, webCapability{c, nil})
+		wc = append(wc, WebCapability{c, nil})
 	}
 	return
 }
@@ -28,7 +28,7 @@ var normalCapabilities = []string{"AT", "create_super_token", "tokeninfo_introsp
 var warningCapabilities = []string{"list_super_tokens"}
 var dangerCapabilities = []string{"settings"}
 
-func (c webCapability) getIntClass() int {
+func (c WebCapability) getIntClass() int {
 	if c.intClass != nil {
 		return *c.intClass
 	}
@@ -47,15 +47,15 @@ func (c webCapability) getIntClass() int {
 	return -1
 }
 
-func (c webCapability) getDangerLevel() int {
+func (c WebCapability) getDangerLevel() int {
 	return c.getIntClass()
 }
 
-func (c webCapability) ColorClass() string {
+func (c WebCapability) ColorClass() string {
 	return textColorByDanger(c.getDangerLevel())
 }
 
-func (c webCapability) CapabilityLevel() string {
+func (c WebCapability) CapabilityLevel() string {
 	intClass := c.getIntClass()
 	switch intClass {
 	case 0:
@@ -68,6 +68,6 @@ func (c webCapability) CapabilityLevel() string {
 	return ""
 }
 
-func (c webCapability) IsCreateST() bool {
+func (c WebCapability) IsCreateST() bool {
 	return c.Name == capabilities.CapabilityCreateST.Name
 }
diff --git a/internal/endpoints/token/super/pkg/authCodeFlowRequest.go b/internal/endpoints/token/super/pkg/authCodeFlowRequest.go
index fb22f0b1..bbcaea57 100644
--- a/internal/endpoints/token/super/pkg/authCodeFlowRequest.go
+++ b/internal/endpoints/token/super/pkg/authCodeFlowRequest.go
@@ -18,10 +18,7 @@ type AuthCodeFlowRequest struct {
 
 // Native checks if the request is native
 func (r *AuthCodeFlowRequest) Native() bool {
-	if r.RedirectType == redirectTypeNative {
-		return true
-	}
-	return false
+	return r.RedirectType == redirectTypeNative
 }
 
 // UnmarshalJSON implements the json unmarshaler interface
diff --git a/pkg/model/oidcFlow.go b/pkg/model/oidcFlow.go
index 20c71620..cb3f923f 100644
--- a/pkg/model/oidcFlow.go
+++ b/pkg/model/oidcFlow.go
@@ -78,7 +78,6 @@ func (f OIDCFlow) AddToSliceIfNotFound(s *[]OIDCFlow) {
 		return
 	}
 	*s = append(*s, f)
-	return
 }
 
 // OIDCFlowIsInSlice checks if a OIDCFlow is present in a slice of OIDCFlows
diff --git a/pkg/mytokenlib/supertoken.go b/pkg/mytokenlib/supertoken.go
index 3dd52f0d..f2b49b54 100644
--- a/pkg/mytokenlib/supertoken.go
+++ b/pkg/mytokenlib/supertoken.go
@@ -141,8 +141,7 @@ func (my *Mytoken) PollOnce(pollingCode string) (string, bool, error) {
 	}
 	var myErr *MytokenError
 	if errors.As(err, &myErr) {
-		switch myErr.err {
-		case model.ErrorAuthorizationPending:
+		if myErr.err == model.ErrorAuthorizationPending {
 			err = nil
 		}
 	}
diff --git a/shared/utils/jwtutils/jwtutils.go b/shared/utils/jwtutils/jwtutils.go
index 5f60c393..acf0aed1 100644
--- a/shared/utils/jwtutils/jwtutils.go
+++ b/shared/utils/jwtutils/jwtutils.go
@@ -50,14 +50,14 @@ func GetAudiencesFromJWT(token string) ([]string, bool) {
 		return nil, false
 	}
 	auds := res[0].Value
-	switch auds.(type) {
+	switch v := auds.(type) {
 	case string:
-		return []string{auds.(string)}, true
+		return []string{v}, true
 	case []string:
-		return auds.([]string), true
+		return v, true
 	case []interface{}:
 		strs := []string{}
-		for _, s := range auds.([]interface{}) {
+		for _, s := range v {
 			str, ok := s.(string)
 			if !ok {
 				return nil, false
diff --git a/shared/utils/utils.go b/shared/utils/utils.go
index 7b694d53..c0396657 100644
--- a/shared/utils/utils.go
+++ b/shared/utils/utils.go
@@ -111,11 +111,10 @@ func IPIsIn(ip string, ips []string) bool {
 			if ipNetB != nil && ipNetB.Contains(ipA) {
 				return true
 			}
-		} else {
-			if ip == ipp {
-				return true
-			}
+		} else if ip == ipp {
+			return true
 		}
+
 	}
 	return false
 }
-- 
GitLab