Skip to content
Snippets Groups Projects
Commit 8cf31513 authored by Gabriel Zachmann's avatar Gabriel Zachmann
Browse files

fix some code issues

parent a42a8a1e
No related branches found
No related tags found
No related merge requests found
package main package main
import ( import (
"github.com/jmoiron/sqlx"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/oidc-mytoken/server/internal/config" "github.com/oidc-mytoken/server/internal/config"
...@@ -11,21 +12,24 @@ import ( ...@@ -11,21 +12,24 @@ import (
func main() { func main() {
config.Load() config.Load()
loggerUtils.Init() loggerUtils.Init()
if err := db.Connect(); err != nil { db.Connect()
log.WithError(err).Fatal()
}
deleteExpiredTransferCodes() deleteExpiredTransferCodes()
deleteExpiredAuthInfo() deleteExpiredAuthInfo()
} }
func deleteExpiredTransferCodes() { func execSimpleQuery(sql string) {
if _, err := db.DB().Exec(`DELETE FROM ProxyTokens WHERE id = ANY(SELECT id FROM TransferCodesAttributes WHERE expires_at < CURRENT_TIMESTAMP())`); err != nil { if err := db.RunWithinTransaction(nil, func(tx *sqlx.Tx) error {
_, err := tx.Exec(sql)
return err
}); err != nil {
log.WithError(err).Error() log.WithError(err).Error()
} }
} }
func deleteExpiredTransferCodes() {
execSimpleQuery(`DELETE FROM ProxyTokens WHERE id = ANY(SELECT id FROM TransferCodesAttributes WHERE expires_at < CURRENT_TIMESTAMP())`)
}
func deleteExpiredAuthInfo() { func deleteExpiredAuthInfo() {
if _, err := db.DB().Exec(`DELETE FROM AuthInfo WHERE expires_at < CURRENT_TIMESTAMP()`); err != nil { execSimpleQuery(`DELETE FROM AuthInfo WHERE expires_at < CURRENT_TIMESTAMP()`)
log.WithError(err).Error()
}
} }
...@@ -14,7 +14,7 @@ import ( ...@@ -14,7 +14,7 @@ import (
"github.com/oidc-mytoken/server/shared/utils/issuerUtils" "github.com/oidc-mytoken/server/shared/utils/issuerUtils"
) )
var defaultConfig = config{ var defaultConfig = Config{
Server: serverConf{ Server: serverConf{
Port: 8000, Port: 8000,
TLS: tlsConf{ TLS: tlsConf{
...@@ -70,8 +70,8 @@ var defaultConfig = config{ ...@@ -70,8 +70,8 @@ var defaultConfig = config{
}, },
} }
// config holds the server configuration // Config holds the server configuration
type config struct { type Config struct {
IssuerURL string `yaml:"issuer"` IssuerURL string `yaml:"issuer"`
Server serverConf `yaml:"server"` Server serverConf `yaml:"server"`
GeoIPDBFile string `yaml:"geo_ip_db_file"` GeoIPDBFile string `yaml:"geo_ip_db_file"`
...@@ -166,10 +166,10 @@ type ProviderConf struct { ...@@ -166,10 +166,10 @@ type ProviderConf struct {
AudienceRequestParameter string `yaml:"audience_request_parameter"` AudienceRequestParameter string `yaml:"audience_request_parameter"`
} }
var conf *config var conf *Config
// Get returns the config // Get returns the Config
func Get() *config { func Get() *Config {
return conf return conf
} }
...@@ -240,7 +240,7 @@ var possibleConfigLocations = []string{ ...@@ -240,7 +240,7 @@ var possibleConfigLocations = []string{
"/etc/mytoken", "/etc/mytoken",
} }
// Load reads the config file and populates the config struct; then validates the config // Load reads the config file and populates the Config struct; then validates the Config
func Load() { func Load() {
load() load()
if err := validate(); err != nil { if err := validate(); err != nil {
...@@ -258,7 +258,7 @@ func load() { ...@@ -258,7 +258,7 @@ func load() {
} }
} }
// LoadForSetup reads the config file and populates the config struct; it does not validate the config, since this is not required for setup // LoadForSetup reads the config file and populates the Config struct; it does not validate the Config, since this is not required for setup
func LoadForSetup() { func LoadForSetup() {
load() load()
} }
...@@ -5,14 +5,14 @@ import ( ...@@ -5,14 +5,14 @@ import (
"github.com/oidc-mytoken/server/shared/utils" "github.com/oidc-mytoken/server/shared/utils"
) )
type webCapability struct { type WebCapability struct {
capabilities.Capability capabilities.Capability
intClass *int intClass *int
} }
func WebCapabilities(cc capabilities.Capabilities) (wc []webCapability) { func WebCapabilities(cc capabilities.Capabilities) (wc []WebCapability) {
for _, c := range cc { for _, c := range cc {
wc = append(wc, webCapability{c, nil}) wc = append(wc, WebCapability{c, nil})
} }
return return
} }
...@@ -28,7 +28,7 @@ var normalCapabilities = []string{"AT", "create_super_token", "tokeninfo_introsp ...@@ -28,7 +28,7 @@ var normalCapabilities = []string{"AT", "create_super_token", "tokeninfo_introsp
var warningCapabilities = []string{"list_super_tokens"} var warningCapabilities = []string{"list_super_tokens"}
var dangerCapabilities = []string{"settings"} var dangerCapabilities = []string{"settings"}
func (c webCapability) getIntClass() int { func (c WebCapability) getIntClass() int {
if c.intClass != nil { if c.intClass != nil {
return *c.intClass return *c.intClass
} }
...@@ -47,15 +47,15 @@ func (c webCapability) getIntClass() int { ...@@ -47,15 +47,15 @@ func (c webCapability) getIntClass() int {
return -1 return -1
} }
func (c webCapability) getDangerLevel() int { func (c WebCapability) getDangerLevel() int {
return c.getIntClass() return c.getIntClass()
} }
func (c webCapability) ColorClass() string { func (c WebCapability) ColorClass() string {
return textColorByDanger(c.getDangerLevel()) return textColorByDanger(c.getDangerLevel())
} }
func (c webCapability) CapabilityLevel() string { func (c WebCapability) CapabilityLevel() string {
intClass := c.getIntClass() intClass := c.getIntClass()
switch intClass { switch intClass {
case 0: case 0:
...@@ -68,6 +68,6 @@ func (c webCapability) CapabilityLevel() string { ...@@ -68,6 +68,6 @@ func (c webCapability) CapabilityLevel() string {
return "" return ""
} }
func (c webCapability) IsCreateST() bool { func (c WebCapability) IsCreateST() bool {
return c.Name == capabilities.CapabilityCreateST.Name return c.Name == capabilities.CapabilityCreateST.Name
} }
...@@ -18,10 +18,7 @@ type AuthCodeFlowRequest struct { ...@@ -18,10 +18,7 @@ type AuthCodeFlowRequest struct {
// Native checks if the request is native // Native checks if the request is native
func (r *AuthCodeFlowRequest) Native() bool { func (r *AuthCodeFlowRequest) Native() bool {
if r.RedirectType == redirectTypeNative { return r.RedirectType == redirectTypeNative
return true
}
return false
} }
// UnmarshalJSON implements the json unmarshaler interface // UnmarshalJSON implements the json unmarshaler interface
......
...@@ -78,7 +78,6 @@ func (f OIDCFlow) AddToSliceIfNotFound(s *[]OIDCFlow) { ...@@ -78,7 +78,6 @@ func (f OIDCFlow) AddToSliceIfNotFound(s *[]OIDCFlow) {
return return
} }
*s = append(*s, f) *s = append(*s, f)
return
} }
// OIDCFlowIsInSlice checks if a OIDCFlow is present in a slice of OIDCFlows // OIDCFlowIsInSlice checks if a OIDCFlow is present in a slice of OIDCFlows
......
...@@ -141,8 +141,7 @@ func (my *Mytoken) PollOnce(pollingCode string) (string, bool, error) { ...@@ -141,8 +141,7 @@ func (my *Mytoken) PollOnce(pollingCode string) (string, bool, error) {
} }
var myErr *MytokenError var myErr *MytokenError
if errors.As(err, &myErr) { if errors.As(err, &myErr) {
switch myErr.err { if myErr.err == model.ErrorAuthorizationPending {
case model.ErrorAuthorizationPending:
err = nil err = nil
} }
} }
......
...@@ -50,14 +50,14 @@ func GetAudiencesFromJWT(token string) ([]string, bool) { ...@@ -50,14 +50,14 @@ func GetAudiencesFromJWT(token string) ([]string, bool) {
return nil, false return nil, false
} }
auds := res[0].Value auds := res[0].Value
switch auds.(type) { switch v := auds.(type) {
case string: case string:
return []string{auds.(string)}, true return []string{v}, true
case []string: case []string:
return auds.([]string), true return v, true
case []interface{}: case []interface{}:
strs := []string{} strs := []string{}
for _, s := range auds.([]interface{}) { for _, s := range v {
str, ok := s.(string) str, ok := s.(string)
if !ok { if !ok {
return nil, false return nil, false
......
...@@ -111,11 +111,10 @@ func IPIsIn(ip string, ips []string) bool { ...@@ -111,11 +111,10 @@ func IPIsIn(ip string, ips []string) bool {
if ipNetB != nil && ipNetB.Contains(ipA) { if ipNetB != nil && ipNetB.Contains(ipA) {
return true return true
} }
} else { } else if ip == ipp {
if ip == ipp { return true
return true
}
} }
} }
return false return false
} }
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment