diff --git a/db/Store.go b/db/Store.go index 2f1fe2dd..2f5ea5f8 100644 --- a/db/Store.go +++ b/db/Store.go @@ -3,6 +3,7 @@ package db import ( "errors" log "github.com/Sirupsen/logrus" + "reflect" "time" ) @@ -28,12 +29,13 @@ type RetrieveQueryParams struct { type ObjectScope int type ObjectProperties struct { - TableName string - IsGlobal bool // doesn't belong to other table, for example to project or user. - ForeignColumnName string - PrimaryColumnName string - SortableColumns []string - SortInverted bool + TableName string + IsGlobal bool // doesn't belong to other table, for example to project or user. + ForeignColumnSuffix string + PrimaryColumnName string + SortableColumns []string + SortInverted bool + Type reflect.Type } var ErrNotFound = errors.New("no rows in result set") @@ -43,7 +45,7 @@ func ValidateUsername(login string) error { return nil } -type Transaction interface {} +type Transaction interface{} type Store interface { Connect() error @@ -238,70 +240,73 @@ func getEventUsername(d Store, evt Event) (username string, err error) { } var AccessKeyProps = ObjectProperties{ - TableName: "access_key", - SortableColumns: []string{"name", "type"}, - PrimaryColumnName: "id", -} - -var GlobalAccessKeyProps = ObjectProperties{ - IsGlobal: true, - TableName: "access_key", - SortableColumns: []string{"name", "type"}, - ForeignColumnName: "ssh_key_id", - PrimaryColumnName: "id", + TableName: "access_key", + SortableColumns: []string{"name", "type"}, + ForeignColumnSuffix: "key_id", + PrimaryColumnName: "id", + Type: reflect.TypeOf(AccessKey{}), } var EnvironmentProps = ObjectProperties{ - TableName: "project__environment", - SortableColumns: []string{"name"}, - ForeignColumnName: "environment_id", - PrimaryColumnName: "id", + TableName: "project__environment", + SortableColumns: []string{"name"}, + ForeignColumnSuffix: "environment_id", + PrimaryColumnName: "id", + Type: reflect.TypeOf(Environment{}), } var InventoryProps = ObjectProperties{ - TableName: "project__inventory", - SortableColumns: []string{"name"}, - ForeignColumnName: "inventory_id", - PrimaryColumnName: "id", + TableName: "project__inventory", + SortableColumns: []string{"name"}, + ForeignColumnSuffix: "inventory_id", + PrimaryColumnName: "id", + Type: reflect.TypeOf(Inventory{}), } var RepositoryProps = ObjectProperties{ - TableName: "project__repository", - ForeignColumnName: "repository_id", - PrimaryColumnName: "id", + TableName: "project__repository", + ForeignColumnSuffix: "repository_id", + PrimaryColumnName: "id", + Type: reflect.TypeOf(Repository{}), } var TemplateProps = ObjectProperties{ TableName: "project__template", SortableColumns: []string{"name"}, PrimaryColumnName: "id", + Type: reflect.TypeOf(Template{}), } var ScheduleProps = ObjectProperties{ TableName: "project__schedule", PrimaryColumnName: "id", + Type: reflect.TypeOf(Schedule{}), } var ProjectUserProps = ObjectProperties{ TableName: "project__user", PrimaryColumnName: "user_id", + Type: reflect.TypeOf(ProjectUser{}), } var ProjectProps = ObjectProperties{ TableName: "project", IsGlobal: true, PrimaryColumnName: "id", + Type: reflect.TypeOf(Project{}), } var UserProps = ObjectProperties{ TableName: "user", IsGlobal: true, PrimaryColumnName: "id", + Type: reflect.TypeOf(User{}), } var SessionProps = ObjectProperties{ TableName: "session", PrimaryColumnName: "id", + Type: reflect.TypeOf(Session{}), } var TokenProps = ObjectProperties{ @@ -314,8 +319,10 @@ var TaskProps = ObjectProperties{ IsGlobal: true, PrimaryColumnName: "id", SortInverted: true, + Type: reflect.TypeOf(Task{}), } var TaskOutputProps = ObjectProperties{ - TableName: "task__output", + TableName: "task__output", + Type: reflect.TypeOf(TaskOutput{}), } diff --git a/db/bolt/BoltDb.go b/db/bolt/BoltDb.go index f39d4bcb..01994f78 100644 --- a/db/bolt/BoltDb.go +++ b/db/bolt/BoltDb.go @@ -9,6 +9,7 @@ import ( "go.etcd.io/bbolt" "reflect" "sort" + "strings" ) const MaxID = 2147483647 @@ -109,12 +110,12 @@ func (d *BoltDb) getObject(bucketID int, props db.ObjectProperties, objectID obj return } -// getFieldNameByTag tries to find field by tag name and value in provided type. +// getFieldNameByTagSuffix tries to find field by tag name and value in provided type. // It returns error if field not found. -func getFieldNameByTag(t reflect.Type, tagName string, tagValue string) (string, error) { +func getFieldNameByTagSuffix(t reflect.Type, tagName string, tagValueSuffix string) (string, error) { n := t.NumField() for i := 0; i < n; i++ { - if t.Field(i).Tag.Get(tagName) == tagValue { + if strings.HasSuffix(t.Field(i).Tag.Get(tagName), tagValueSuffix) { return t.Field(i).Name, nil } } @@ -122,7 +123,7 @@ func getFieldNameByTag(t reflect.Type, tagName string, tagValue string) (string, if t.Field(i).Tag != "" || t.Field(i).Type.Kind() != reflect.Struct { continue } - str, err := getFieldNameByTag(t.Field(i).Type, tagName, tagValue) + str, err := getFieldNameByTagSuffix(t.Field(i).Type, tagName, tagValueSuffix) if err == nil { return str, nil } @@ -134,7 +135,7 @@ func sortObjects(objects interface{}, sortBy string, sortInverted bool) error { objectsValue := reflect.ValueOf(objects).Elem() objType := objectsValue.Type().Elem() - fieldName, err := getFieldNameByTag(objType, "db", sortBy) + fieldName, err := getFieldNameByTagSuffix(objType, "db", sortBy) if err != nil { return err } @@ -315,63 +316,68 @@ func (d *BoltDb) getObjects(bucketID int, props db.ObjectProperties, params db.R }) } -func (d *BoltDb) isObjectInUse(bucketID int, props db.ObjectProperties, objID objectID, userProps db.ObjectProperties) (inUse bool, err error) { - var templates []db.Template +func isObjectBelongTo(props db.ObjectProperties, objID objectID, tpl interface{}) bool { + if props.ForeignColumnSuffix == "" { + return false + } - err = d.getObjects(bucketID, userProps, db.RetrieveQueryParams{}, func (tpl interface{}) bool { - if props.ForeignColumnName == "" { + fieldName, err := getFieldNameByTagSuffix(reflect.TypeOf(tpl), "db", props.ForeignColumnSuffix) + + if err != nil { + return false + } + + f := reflect.ValueOf(tpl).FieldByName(fieldName) + + if f.IsZero() { + return false + } + + if f.Kind() == reflect.Ptr { + if f.IsNil() { return false } - fieldName, err := getFieldNameByTag(reflect.TypeOf(tpl), "db", props.ForeignColumnName) + f = f.Elem() + } - if err != nil { - return false - } + var fVal objectID + switch f.Kind() { + case reflect.Int, + reflect.Int8, + reflect.Int16, + reflect.Int32, + reflect.Int64, + reflect.Uint, + reflect.Uint8, + reflect.Uint16, + reflect.Uint32, + reflect.Uint64: + fVal = intObjectID(f.Int()) + case reflect.String: + fVal = strObjectID(f.String()) + } - f := reflect.ValueOf(tpl).FieldByName(fieldName) + if fVal == nil { + return false + } - if f.IsZero() { - return false - } + return bytes.Equal(fVal.ToBytes(), objID.ToBytes()) +} - if f.Kind() == reflect.Ptr { - if f.IsNil() { - return false - } +// isObjectInUse checks if objID associated with any object in foreignTableProps. +func (d *BoltDb) isObjectInUse(bucketID int, objProps db.ObjectProperties, objID objectID, foreignTableProps db.ObjectProperties) (inUse bool, err error) { + templates := reflect.New(reflect.SliceOf(foreignTableProps.Type)) - f = f.Elem() - } - - var fVal objectID - switch f.Kind() { - case reflect.Int, - reflect.Int8, - reflect.Int16, - reflect.Int32, - reflect.Int64, - reflect.Uint, - reflect.Uint8, - reflect.Uint16, - reflect.Uint32, - reflect.Uint64: - fVal = intObjectID(f.Int()) - case reflect.String: - fVal = strObjectID(f.String()) - } - - if fVal == nil { - return false - } - - return bytes.Equal(fVal.ToBytes(), objID.ToBytes()) - }, &templates) + err = d.getObjects(bucketID, foreignTableProps, db.RetrieveQueryParams{}, func (foreignObj interface{}) bool { + return isObjectBelongTo(objProps, objID, foreignObj) + }, templates.Interface()) if err != nil { return } - inUse = len(templates) > 0 + inUse = templates.Elem().Len() > 0 return } @@ -451,7 +457,7 @@ func (d *BoltDb) updateObject(bucketID int, props db.ObjectProperties, object in return db.ErrNotFound } - idFieldName, err := getFieldNameByTag(reflect.TypeOf(object), "db", props.PrimaryColumnName) + idFieldName, err := getFieldNameByTagSuffix(reflect.TypeOf(object), "db", props.PrimaryColumnName) if err != nil { return err @@ -510,7 +516,7 @@ func (d *BoltDb) createObject(bucketID int, props db.ObjectProperties, object in var objectID objectID if props.PrimaryColumnName != "" { - idFieldName, err := getFieldNameByTag(reflect.TypeOf(object), "db", props.PrimaryColumnName) + idFieldName, err := getFieldNameByTagSuffix(reflect.TypeOf(object), "db", props.PrimaryColumnName) if err != nil { return err diff --git a/db/bolt/BoltDb_test.go b/db/bolt/BoltDb_test.go index 9703c521..1ac78deb 100644 --- a/db/bolt/BoltDb_test.go +++ b/db/bolt/BoltDb_test.go @@ -190,7 +190,7 @@ func TestSortObjects(t *testing.T) { } func TestGetFieldNameByTag(t *testing.T) { - f, err := getFieldNameByTag(reflect.TypeOf(test1{}), "db", "first_name") + f, err := getFieldNameByTagSuffix(reflect.TypeOf(test1{}), "db", "first_name") if err != nil { t.Fatal(err.Error()) } @@ -201,7 +201,7 @@ func TestGetFieldNameByTag(t *testing.T) { } func TestGetFieldNameByTag2(t *testing.T) { - f, err := getFieldNameByTag(reflect.TypeOf(db.UserWithPwd{}), "db", "id") + f, err := getFieldNameByTagSuffix(reflect.TypeOf(db.UserWithPwd{}), "db", "id") if err != nil { t.Fatal(err.Error()) } diff --git a/db/sql/SqlDb.go b/db/sql/SqlDb.go index 7db29517..0d24567e 100644 --- a/db/sql/SqlDb.go +++ b/db/sql/SqlDb.go @@ -235,12 +235,12 @@ func (d *SqlDb) getObjects(projectID int, props db.ObjectProperties, params db.R } func (d *SqlDb) isObjectInUse(projectID int, props db.ObjectProperties, objectID int) (bool, error) { - if props.ForeignColumnName == "" { + if props.ForeignColumnSuffix == "" { return false, nil } templatesC, err := d.sql.SelectInt( - "select count(1) from project__template where project_id=? and " + props.ForeignColumnName+ "=?", + "select count(1) from project__template where project_id=? and " + props.ForeignColumnSuffix+ "=?", projectID, objectID)