Skip to content
Snippets Groups Projects
Commit d13f0a4b authored by Gabriel Zachmann's avatar Gabriel Zachmann
Browse files

add db and user setup

parent b7eabca5
No related branches found
No related tags found
No related merge requests found
......@@ -75,15 +75,15 @@ var app = &cli.App{
Name: "password",
Aliases: []string{"p"},
Usage: "The password for connecting to the database",
EnvVars: []string{"DB_PASSWORD"},
EnvVars: []string{"DB_ROOT_PASSWORD", "DB_ROOT_PW"},
Destination: &dbConfig.Password,
Placeholder: "PASSWORD",
},
&cli.StringFlag{
Name: "password-file",
Aliases: []string{"pass-file"},
Aliases: []string{"pw-file"},
Usage: "Read the password for connecting to the database from this file",
EnvVars: []string{"DB_PASSWORD_FILE"},
EnvVars: []string{"DB_PASSWORD_FILE", "DB_PW_FILE"},
Destination: &dbConfig.PasswordFile,
Placeholder: "FILE",
},
......@@ -95,6 +95,7 @@ var app = &cli.App{
Value: cli.NewStringSlice("localhost"),
Destination: &dbConfig.Hosts,
Placeholder: "HOST",
TakesFile: true,
},
},
Action: func(context *cli.Context) error {
......@@ -108,8 +109,8 @@ var app = &cli.App{
} else if !force {
return fmt.Errorf("No mytoken servers specified. Please provide mytoken servers or use '-f' to force database migration.")
}
if dbConfig.Password == "" {
dbConfig.Password = prompter.Password("Enter password")
if dbConfig.GetPassword() == "" {
dbConfig.Password = prompter.Password(fmt.Sprintf("Enter db password for user '%s'", dbConfig.User))
}
dbConfig.ReconnectInterval = 60
dbConfig.DBConf.Hosts = dbConfig.Hosts.Value()
......
package main
import (
"fmt"
"os"
"os/exec"
"strings"
"time"
......@@ -15,6 +12,7 @@ import (
"github.com/oidc-mytoken/server/internal/db/dbmigrate"
"github.com/oidc-mytoken/server/internal/db/dbrepo/versionrepo"
"github.com/oidc-mytoken/server/internal/model/version"
"github.com/oidc-mytoken/server/internal/utils/dbcl"
)
func did(state versionrepo.DBVersionState, version string) (beforeDone, afterDone bool) {
......@@ -105,7 +103,7 @@ func updateCallback(tx *sqlx.Tx, cmds, version string, done map[string]bool, dbU
return nil
}
return db.RunWithinTransaction(tx, func(tx *sqlx.Tx) error {
if err := runDBCommands(cmds); err != nil {
if err := dbcl.RunDBCommands(cmds, dbConfig.DBConf, true); err != nil {
return err
}
return dbUpdateCallback(tx, version)
......@@ -141,20 +139,3 @@ func getVersionForNode(node string) (string, error) {
}
return my.Version, nil
}
func runDBCommands(cmds string) error {
cmd := exec.Command("sh", "-c", fmt.Sprintf("mysql -uroot -p%s --protocol tcp -h %s %s", dbConfig.GetPassword(), dbConfig.Hosts.Value()[0], dbConfig.DB))
cmdIn, err := cmd.StdinPipe()
if err != nil {
return err
}
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr
if _, err = cmdIn.Write([]byte(cmds)); err != nil {
return err
}
if err = cmdIn.Close(); err != nil {
return err
}
return cmd.Run()
}
PREPARE createDB FROM CONCAT('CREATE DATABASE IF NOT EXISTS ', @DB);
EXECUTE createDB;
DEALLOCATE PREPARE createDB;
SET @DB_TABLES = CONCAT(@DB, '.*');
PREPARE createUser FROM CONCAT('CREATE OR REPLACE USER ', @USER, ' IDENTIFIED BY "', @PASSWORD, '"');
PREPARE grantRights FROM CONCAT('GRANT Execute, Select, Show view, Insert, Update, Delete ON ', @DB_TABLES, ' TO ',
@USER);
EXECUTE createUser;
EXECUTE grantRights;
FLUSH PRIVILEGES;
DEALLOCATE PREPARE createUser;
DEALLOCATE PREPARE grantRights;
PREPARE setDB FROM 'SET @DB = ?';
PREPARE setUser FROM 'SET @USER = ?';
PREPARE setPassword FROM 'SET @PASSWORD = ?';
package main
import (
"embed"
"fmt"
"io/ioutil"
"os"
......@@ -13,11 +14,59 @@ import (
"github.com/oidc-mytoken/server/internal/config"
"github.com/oidc-mytoken/server/internal/jws"
"github.com/oidc-mytoken/server/internal/model/version"
"github.com/oidc-mytoken/server/internal/utils/dbcl"
loggerUtils "github.com/oidc-mytoken/server/internal/utils/logger"
"github.com/oidc-mytoken/server/internal/utils/zipdownload"
"github.com/oidc-mytoken/server/shared/utils/fileutil"
)
type _rootDBCredentials struct {
User string
Password string
PasswordFile string
}
var rootDBCredentials _rootDBCredentials
func (cred _rootDBCredentials) ToDBConf() config.DBConf {
return config.DBConf{
Hosts: config.Get().DB.Hosts,
User: cred.User,
Password: cred.Password,
PasswordFile: cred.PasswordFile,
ReconnectInterval: config.Get().DB.ReconnectInterval,
}
}
var dbFlags = []cli.Flag{
&cli.StringFlag{
Name: "user",
Aliases: []string{"u", "root-user", "db-user"},
Usage: "The username for the (root) user used for setting up the db",
EnvVars: []string{"DB_USER"},
Value: "root",
Destination: &rootDBCredentials.User,
Placeholder: "ROOT",
},
&cli.StringFlag{
Name: "password",
Aliases: []string{"p", "pw", "db-password", "db-pw"},
Usage: "The password for the (root) user used for setting up the db",
EnvVars: []string{"DB_PW", "DB_PASSWORD"},
Destination: &rootDBCredentials.Password,
Placeholder: "PASSWORD",
},
&cli.StringFlag{
Name: "password-file",
Aliases: []string{"pw-file"},
Usage: "Read the password for connecting to the database from this file",
EnvVars: []string{"DB_PASSWORD_FILE", "DB_PW_FILE"},
Destination: &rootDBCredentials.PasswordFile,
TakesFile: true,
Placeholder: "FILE",
},
}
var app = &cli.App{
Name: "mytoken-setup",
Usage: "Command line client for easily setting up a mytoken server",
......@@ -49,6 +98,26 @@ var app = &cli.App{
},
},
},
&cli.Command{
Name: "db",
Usage: "Setups for the database",
Flags: append([]cli.Flag{}, dbFlags...),
Subcommands: cli.Commands{
&cli.Command{
Name: "db",
Aliases: []string{"database"},
Usage: "Creates the database in the database server",
Action: createDB,
Flags: append([]cli.Flag{}, dbFlags...),
},
&cli.Command{
Name: "user",
Usage: "Creates the normal database user",
Action: createUser,
Flags: append([]cli.Flag{}, dbFlags...),
},
},
},
},
}
......@@ -94,3 +163,66 @@ func createSigningKey(_ *cli.Context) error {
fmt.Printf("Wrote key to file '%s'.\n", filepath)
return nil
}
//go:embed scripts
var sqlScripts embed.FS
func readSQLFile(path string) (string, error) {
data, err := sqlScripts.ReadFile(path)
if err != nil {
return "", err
}
return string(data), nil
}
func _getSetVars() (string, error) {
return readSQLFile("scripts/vars.sql")
}
func getSetVarsCommands(db, user, password string) (string, error) {
cmds, err := _getSetVars()
if err != nil {
return "", err
}
if db != "" {
cmds += fmt.Sprintf(`EXECUTE setDB USING '%s';\n`, db)
}
if user != "" {
cmds += fmt.Sprintf(`EXECUTE setUser USING '%s';\n`, user)
}
if password != "" {
cmds += fmt.Sprintf(`EXECUTE setPassword USING '%s';\n`, password)
}
return cmds, nil
}
func getDBCmds() (string, error) {
return readSQLFile("scripts/db.sql")
}
func getUserCmds() (string, error) {
return readSQLFile("scripts/user.sql")
}
func createDB(_ *cli.Context) error {
cmds, err := getSetVarsCommands(config.Get().DB.DB, config.Get().DB.User, config.Get().DB.GetPassword())
if err != nil {
return err
}
dbCmds, err := getDBCmds()
if err != nil {
return err
}
cmds += dbCmds
return dbcl.RunDBCommands(cmds, rootDBCredentials.ToDBConf(), true)
}
func createUser(_ *cli.Context) error {
cmds, err := getSetVarsCommands(config.Get().DB.DB, config.Get().DB.User, config.Get().DB.GetPassword())
if err != nil {
return err
}
userCmds, err := getUserCmds()
if err != nil {
return err
}
cmds += userCmds
return dbcl.RunDBCommands(cmds, rootDBCredentials.ToDBConf(), true)
}
......@@ -37,10 +37,8 @@ var defaultConfig = Config{
},
},
DB: DBConf{
Hosts: []string{"localhost"},
User: "mytoken",
// The default value for Password is "mytoken", but it is not set here, but returned in the GetPassword function,
// because the default value is only used if no password and no password file are provided
Hosts: []string{"localhost"},
User: "mytoken",
DB: "mytoken",
ReconnectInterval: 60,
},
......@@ -229,7 +227,7 @@ func (conf *DBConf) GetPassword() string {
return conf.Password
}
if conf.PasswordFile == "" {
return "mytoken"
return ""
}
content, err := ioutil.ReadFile(conf.PasswordFile)
if err != nil {
......
package dbcl
import (
"fmt"
"os"
"os/exec"
"github.com/pkg/errors"
"github.com/oidc-mytoken/server/internal/config"
)
// RunDBCommands executes SQL stmts through the mysql cli
func RunDBCommands(cmds string, dbConfig config.DBConf, printOutput bool) error {
mysqlCmd := fmt.Sprintf("mysql -u%s -p%s --protocol tcp -h %s",
dbConfig.User, dbConfig.GetPassword(), dbConfig.Hosts[0])
if dbConfig.DB != "" {
mysqlCmd += fmt.Sprintf(" %s", dbConfig.DB)
}
cmd := exec.Command("sh", "-c", mysqlCmd)
cmdIn, err := cmd.StdinPipe()
if err != nil {
return errors.WithStack(err)
}
if printOutput {
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr
}
if _, err = cmdIn.Write([]byte(cmds)); err != nil {
return errors.WithStack(err)
}
if err = cmdIn.Close(); err != nil {
return errors.WithStack(err)
}
return cmd.Run()
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment