diff --git a/db/Store.go b/db/Store.go index f8fd4fad..665e1077 100644 --- a/db/Store.go +++ b/db/Store.go @@ -257,6 +257,7 @@ var SessionProps = ObjectProperties{ var TokenProps = ObjectProperties{ TableName: "user__token", PrimaryColumnName: "id", + Type: reflect.TypeOf(APIToken{}), } var TaskProps = ObjectProperties{ diff --git a/db/bolt/BoltDb.go b/db/bolt/BoltDb.go index 2465855a..bcdf16ea 100644 --- a/db/bolt/BoltDb.go +++ b/db/bolt/BoltDb.go @@ -426,13 +426,13 @@ func (d *BoltDb) deleteObjectSoft(bucketID int, props db.ObjectProperties, objec return db.ErrNotFound } - d := b.Get(objectID.ToBytes()) + j := b.Get(objectID.ToBytes()) - if d == nil { + if j == nil { return db.ErrNotFound } - return json.Unmarshal(d, &data) + return json.Unmarshal(j, &data) }) if err != nil { @@ -479,7 +479,7 @@ func (d *BoltDb) updateObject(bucketID int, props db.ObjectProperties, object in idValue := reflect.ValueOf(object).FieldByName(idFieldName) - var objectID objectID + var objID objectID switch idValue.Kind() { case reflect.Int, @@ -492,16 +492,16 @@ func (d *BoltDb) updateObject(bucketID int, props db.ObjectProperties, object in reflect.Uint16, reflect.Uint32, reflect.Uint64: - objectID = intObjectID(idValue.Int()) + objID = intObjectID(idValue.Int()) case reflect.String: - objectID = strObjectID(idValue.String()) + objID = strObjectID(idValue.String()) } - if objectID == nil { + if objID == nil { return fmt.Errorf("unsupported ID type") } - if b.Get(objectID.ToBytes()) == nil { + if b.Get(objID.ToBytes()) == nil { return db.ErrNotFound } @@ -510,7 +510,7 @@ func (d *BoltDb) updateObject(bucketID int, props db.ObjectProperties, object in return err } - return b.Put(objectID.ToBytes(), str) + return b.Put(objID.ToBytes(), str) }) } @@ -527,13 +527,13 @@ func (d *BoltDb) createObject(bucketID int, props db.ObjectProperties, object in tmpObj := reflect.New(objPtr.Elem().Type()).Elem() tmpObj.Set(objPtr.Elem()) - var objectID objectID + var objID objectID if props.PrimaryColumnName != "" { - idFieldName, err := getFieldNameByTagSuffix(reflect.TypeOf(object), "db", props.PrimaryColumnName) + idFieldName, err2 := getFieldNameByTagSuffix(reflect.TypeOf(object), "db", props.PrimaryColumnName) - if err != nil { - return err + if err2 != nil { + return err2 } idValue := tmpObj.FieldByName(idFieldName) @@ -550,9 +550,9 @@ func (d *BoltDb) createObject(bucketID int, props db.ObjectProperties, object in reflect.Uint32, reflect.Uint64: if idValue.Int() == 0 { - id, err2 := b.NextSequence() - if err2 != nil { - return err2 + id, err3 := b.NextSequence() + if err3 != nil { + return err3 } if props.SortInverted { id = MaxID - id @@ -560,18 +560,18 @@ func (d *BoltDb) createObject(bucketID int, props db.ObjectProperties, object in idValue.SetInt(int64(id)) } - objectID = intObjectID(idValue.Int()) + objID = intObjectID(idValue.Int()) case reflect.String: if idValue.String() == "" { return fmt.Errorf("object ID can not be empty string") } - objectID = strObjectID(idValue.String()) + objID = strObjectID(idValue.String()) case reflect.Invalid: - id, err2 := b.NextSequence() - if err2 != nil { - return err2 + id, err3 := b.NextSequence() + if err3 != nil { + return err3 } - objectID = intObjectID(id) + objID = intObjectID(id) default: return fmt.Errorf("unsupported ID type") } @@ -583,10 +583,10 @@ func (d *BoltDb) createObject(bucketID int, props db.ObjectProperties, object in if props.SortInverted { id = MaxID - id } - objectID = intObjectID(id) + objID = intObjectID(id) } - if objectID == nil { + if objID == nil { return fmt.Errorf("object ID can not be nil") } @@ -596,7 +596,7 @@ func (d *BoltDb) createObject(bucketID int, props db.ObjectProperties, object in return err } - return b.Put(objectID.ToBytes(), str) + return b.Put(objID.ToBytes(), str) }) return object, err diff --git a/db/bolt/BoltDb_test.go b/db/bolt/BoltDb_test.go index f61a3066..63c40217 100644 --- a/db/bolt/BoltDb_test.go +++ b/db/bolt/BoltDb_test.go @@ -266,7 +266,7 @@ func TestIsObjectInUse_EnvironmentNil(t *testing.T) { }) if err != nil { - t.Fatal(err.Error()) + t.Fatal(err) } _, err = store.CreateTemplate(db.Template{ @@ -277,17 +277,82 @@ func TestIsObjectInUse_EnvironmentNil(t *testing.T) { }) if err != nil { - t.Fatal(err.Error()) + t.Fatal(err) } isUse, err := store.isObjectInUse(proj.ID, db.EnvironmentProps, intObjectID(10), db.TemplateProps) if err != nil { - t.Fatal(err.Error()) + t.Fatal(err) } if isUse { t.Fatal() } - +} + +func TestBoltDb_CreateAPIToken(t *testing.T) { + store := CreateTestStore() + + user, err := store.CreateUser(db.UserWithPwd{ + Pwd: "3412341234123", + User: db.User{ + Username: "test", + Name: "Test", + Email: "test@example.com", + Admin: true, + }, + }) + if err != nil { + t.Fatal(err) + } + + token, err := store.CreateAPIToken(db.APIToken{ + ID: "f349gyhgqirgysfgsfg34973dsfad", + UserID: user.ID, + }) + if err != nil { + t.Fatal(err) + } + + token2, err := store.GetAPIToken(token.ID) + if err != nil { + t.Fatal(err) + } + + if token2.ID != token.ID { + t.Fatal() + } + + tokens, err := store.GetAPITokens(user.ID) + if err != nil { + t.Fatal(err) + } + + if len(tokens) != 1 { + t.Fatal() + } + + if tokens[0].ID != token.ID { + t.Fatal() + } + + err = store.ExpireAPIToken(user.ID, token.ID) + if err != nil { + t.Fatal(err) + } + + tokens, err = store.GetAPITokens(user.ID) + if err != nil { + t.Fatal(err) + } + + token2, err = store.GetAPIToken(token.ID) + if err != nil { + t.Fatal(err) + } + + if !token2.Expired { + t.Fatal() + } } diff --git a/db/bolt/session.go b/db/bolt/session.go index 9012c0ad..2eb8d2f3 100644 --- a/db/bolt/session.go +++ b/db/bolt/session.go @@ -2,18 +2,22 @@ package bolt import ( "github.com/ansible-semaphore/semaphore/db" + "reflect" "time" ) -var globalTokenObject = db.ObjectProperties{ - TableName: "token", -} - type globalToken struct { ID string `db:"id" json:"id"` UserID int `db:"user_id" json:"user_id"` } +var globalTokenObject = db.ObjectProperties{ + TableName: "token", + PrimaryColumnName: "id", + Type: reflect.TypeOf(globalToken{}), + IsGlobal: true, +} + func (d *BoltDb) CreateSession(session db.Session) (db.Session, error) { newSession, err := d.createObject(session.UserID, db.SessionProps, session) if err != nil { @@ -87,6 +91,6 @@ func (d *BoltDb) TouchSession(userID int, sessionID int) (err error) { } func (d *BoltDb) GetAPITokens(userID int) (tokens []db.APIToken, err error) { - err = d.getObjects(userID, db.SessionProps, db.RetrieveQueryParams{}, nil, &tokens) + err = d.getObjects(userID, db.TokenProps, db.RetrieveQueryParams{}, nil, &tokens) return }