Skip to content
Snippets Groups Projects
mytoken.go 4.57 KiB
package mytokenrepo

import (
	"encoding/base64"

	"github.com/jmoiron/sqlx"
	"github.com/pkg/errors"
	log "github.com/sirupsen/logrus"

	"github.com/oidc-mytoken/api/v0"

	"github.com/oidc-mytoken/server/internal/db"
	eventService "github.com/oidc-mytoken/server/shared/mytoken/event"
	event "github.com/oidc-mytoken/server/shared/mytoken/event/pkg"
	mytoken "github.com/oidc-mytoken/server/shared/mytoken/pkg"
	"github.com/oidc-mytoken/server/shared/mytoken/pkg/mtid"
	"github.com/oidc-mytoken/server/shared/utils/cryptUtils"
)

// MytokenEntry holds the information of a mytoken as it is stored in the
// database. Besides the exported, persisted fields it carries the
// (unexported) refresh-token material that Store needs to persist the entry.
type MytokenEntry struct {
	// ID is the mytoken's ID, taken from the mytoken on creation.
	ID                     mtid.MTID
	// SeqNo is the mytoken's sequence number, taken from the mytoken on creation.
	SeqNo                  uint64
	// ParentID is the ID of the parent mytoken; if it is not a valid hash the
	// entry is a root token (see Root).
	ParentID               mtid.MTID `db:"parent_id"`
	// Token is the mytoken itself; its JWT is used to encrypt encryptionKey.
	Token                  *mytoken.Mytoken
	// rtID is the ID of the already stored refresh token; if nil, Store
	// inserts rtEncrypted as a new refresh token and uses its ID.
	rtID                   *uint64
	// refreshToken is the plaintext refresh token set by InitRefreshToken.
	refreshToken           string
	// encryptionKey is the key the refresh token is encrypted with (randomly
	// generated by InitRefreshToken or passed to SetRefreshToken).
	encryptionKey          []byte
	// rtEncrypted is refreshToken encrypted with encryptionKey.
	rtEncrypted            string
	// encryptionKeyEncrypted is the base64-encoded encryptionKey encrypted with
	// the mytoken's JWT; this is what gets stored for the entry.
	encryptionKeyEncrypted string
	// Name is an optional name for the mytoken.
	Name                   string
	// IP is the IP address the mytoken was created from.
	IP                     string `db:"ip_created"`
	// networkData is the client metadata attached to the creation event in Store.
	networkData            api.ClientMetaData
}

// InitRefreshToken links a new refresh token to this MytokenEntry.
//
// A fresh random 32-byte key is generated and rt is encrypted with it. The key
// itself is base64-encoded and encrypted via cryptUtils.AES256Encrypt using the
// mytoken's JWT, so decrypting the key requires the mytoken. Because rt is a
// new refresh token, any previously linked refresh-token ID is cleared so that
// Store inserts the encrypted rt instead of pairing the new key with an old
// refresh-token row.
//
// The entry is only modified once every step has succeeded; on error it is
// left unchanged.
func (mte *MytokenEntry) InitRefreshToken(rt string) error {
	key := cryptUtils.RandomBytes(32)
	rtEncrypted, err := cryptUtils.AESEncrypt(rt, key)
	if err != nil {
		return err
	}
	jwt, err := mte.Token.ToJWT()
	if err != nil {
		return err
	}
	keyEncrypted, err := cryptUtils.AES256Encrypt(base64.StdEncoding.EncodeToString(key), jwt)
	if err != nil {
		return err
	}

	// Commit all derived state together so a failure above cannot leave a
	// half-initialized entry behind.
	mte.refreshToken = rt
	mte.encryptionKey = key
	mte.rtEncrypted = rtEncrypted
	mte.encryptionKeyEncrypted = keyEncrypted
	mte.rtID = nil
	return nil
}

// SetRefreshToken links this MytokenEntry to an already stored refresh token.
//
// rtID is the ID of the stored refresh token. key is its encryption key
// (presumably the key that refresh token was encrypted with; verify against
// callers). The key is base64-encoded and encrypted via
// cryptUtils.AES256Encrypt using this mytoken's JWT so it can be stored for
// this entry. Since rtID is set, Store will not insert a new refresh token.
//
// The entry is only modified once every step has succeeded; on error it is
// left unchanged.
func (mte *MytokenEntry) SetRefreshToken(rtID uint64, key []byte) error {
	jwt, err := mte.Token.ToJWT()
	if err != nil {
		return err
	}
	keyEncrypted, err := cryptUtils.AES256Encrypt(base64.StdEncoding.EncodeToString(key), jwt)
	if err != nil {
		return err
	}

	// Commit together so a failure above does not leave a key without its
	// encrypted form or refresh-token ID.
	mte.encryptionKey = key
	mte.encryptionKeyEncrypted = keyEncrypted
	mte.rtID = &rtID
	return nil
}

// NewMytokenEntry builds a MytokenEntry for mt. ID and sequence number are
// copied from mt, name is recorded as given, and networkData is kept both for
// the creation IP and for the creation event logged on Store. No refresh token
// is linked yet; InitRefreshToken or SetRefreshToken must be used for that.
func NewMytokenEntry(mt *mytoken.Mytoken, name string, networkData api.ClientMetaData) *MytokenEntry {
	entry := new(MytokenEntry)
	entry.ID, entry.SeqNo = mt.ID, mt.SeqNo
	entry.Token = mt
	entry.Name = name
	entry.IP = networkData.IP
	entry.networkData = networkData
	return entry
}

// Root reports whether this MytokenEntry is a root token, i.e. whether it
// lacks a valid parent ID.
func (mte *MytokenEntry) Root() bool {
	hasParent := mte.ParentID.HashValid()
	return !hasParent
}

// Store persists the MytokenEntry in the database using tx (how a nil tx is
// handled is up to db.RunWithinTransaction).
//
// Within the transaction the encrypted refresh token is inserted via
// CryptStoreRT_Insert unless a refresh-token ID is already linked (see
// SetRefreshToken). Then the mytoken row is inserted, the encrypted encryption
// key is stored for the (refresh token, mytoken) pair, and an MTCreated event
// carrying comment and the entry's network data is logged.
//
// The entry's refresh-token ID is only written back once the transaction
// function has succeeded. A failed or retried attempt in this call therefore
// never leaves the entry pointing at a refresh-token row that was rolled back.
func (mte *MytokenEntry) Store(rlog log.Ext1FieldLogger, tx *sqlx.Tx, comment string) error {
	steStore := mytokenEntryStore{
		ID:       mte.ID,
		SeqNo:    mte.SeqNo,
		ParentID: mte.ParentID,
		Name:     db.NewNullString(mte.Name),
		IP:       mte.IP,
		Iss:      mte.Token.OIDCIssuer,
		Sub:      mte.Token.OIDCSubject,
	}
	// rtID is determined inside the transaction function and assigned to mte
	// only after that function (and thus the transaction) succeeded.
	var rtID uint64
	err := db.RunWithinTransaction(
		rlog, tx, func(tx *sqlx.Tx) error {
			if mte.rtID == nil {
				if _, err := tx.Exec(`CALL CryptStoreRT_Insert(?,@ID)`, mte.rtEncrypted); err != nil {
					return errors.WithStack(err)
				}
				// Read the new refresh-token ID back from the @ID variable the
				// procedure sets.
				if err := tx.Get(&rtID, `SELECT @ID`); err != nil {
					return errors.WithStack(err)
				}
			} else {
				rtID = *mte.rtID
			}
			steStore.RefreshTokenID = rtID
			if err := steStore.Store(rlog, tx); err != nil {
				return err
			}
			if err := storeEncryptionKey(tx, mte.encryptionKeyEncrypted, rtID, mte.ID); err != nil {
				return err
			}
			return eventService.LogEvent(
				rlog, tx, eventService.MTEvent{
					Event: event.FromNumber(event.MTCreated, comment),
					MTID:  mte.ID,
				}, mte.networkData,
			)
		},
	)
	if err != nil {
		return err
	}
	mte.rtID = &rtID
	return nil
}

// storeEncryptionKey stores the encrypted encryption key for refresh token
// rtID and mytoken myid through the EncryptionKeysRT_Insert stored procedure,
// executed on tx. A failure is returned with a stack trace attached.
func storeEncryptionKey(tx *sqlx.Tx, key string, rtID uint64, myid mtid.MTID) error {
	if _, err := tx.Exec(`CALL EncryptionKeysRT_Insert(?,?,?)`, key, rtID, myid); err != nil {
		return errors.WithStack(err)
	}
	return nil
}

// mytokenEntryStore is the flat row representation of a mytoken whose fields
// are passed to the MTokens_Insert stored procedure.
type mytokenEntryStore struct {
	// ID is the mytoken's ID.
	ID             mtid.MTID
	// SeqNo is the mytoken's sequence number.
	SeqNo          uint64
	// ParentID is the parent mytoken's ID.
	ParentID       mtid.MTID `db:"parent_id"`
	// RootID is the root mytoken's ID.
	// NOTE(review): not populated by MytokenEntry.Store and not passed to
	// MTokens_Insert in this file — confirm whether it is still needed.
	RootID         mtid.MTID `db:"root_id"`
	// RefreshTokenID is the ID of the stored (encrypted) refresh token.
	RefreshTokenID uint64    `db:"rt_id"`
	// Name is the optional mytoken name, built with db.NewNullString
	// (presumably NULL when empty — verify).
	Name           db.NullString
	// IP is the IP address the mytoken was created from.
	IP             string `db:"ip_created"`
	// Iss is the OIDC issuer of the mytoken's owner.
	Iss            string
	// Sub is the OIDC subject of the mytoken's owner.
	Sub            string
}

// Store inserts the mytokenEntryStore into the database by calling the
// MTokens_Insert stored procedure within tx (via db.RunWithinTransaction).
// If this is the first token of the user identified by Iss/Sub, the user is
// also added to the db; this presumably happens inside MTokens_Insert and is
// not visible here.
func (e *mytokenEntryStore) Store(rlog log.Ext1FieldLogger, tx *sqlx.Tx) error {
	insert := func(tx *sqlx.Tx) error {
		args := []interface{}{e.Sub, e.Iss, e.ID, e.SeqNo, e.ParentID, e.RefreshTokenID, e.Name, e.IP}
		_, err := tx.Exec(`CALL MTokens_Insert(?,?,?,?,?,?,?,?)`, args...)
		return errors.WithStack(err)
	}
	return db.RunWithinTransaction(rlog, tx, insert)
}