fix(be): add quotes to user table in queties to support postgres

This commit is contained in:
Denis Gukov 2021-08-24 22:52:35 +05:00
parent 63a4a32ac1
commit 551ef97233
11 changed files with 81 additions and 66 deletions

View File

@ -98,14 +98,54 @@ func (d *SqlDb) prepareQuery(query string) string {
return d.prepareQueryWithDialect(query, d.sql.Dialect)
}
func (d *SqlDb) insert(primaryKeyColumnName string, query string, args ...interface{}) (int, error) {
var insertId int64
switch d.sql.Dialect.(type) {
case gorp.PostgresDialect:
query += " returning " + primaryKeyColumnName
err := d.sql.QueryRow(d.prepareQuery(query), args...).Scan(&insertId)
if err != nil {
return 0, err
}
default:
res, err := d.exec(query, args...)
if err != nil {
return 0, err
}
insertId, err = res.LastInsertId()
if err != nil {
return 0, err
}
}
return int(insertId), nil
}
func (d *SqlDb) exec(query string, args ...interface{}) (sql.Result, error) {
return d.sql.Exec(d.prepareQuery(query), args...)
q := d.prepareQuery(query)
return d.sql.Exec(q, args...)
}
func (d *SqlDb) selectOne(holder interface{}, query string, args ...interface{}) error {
return d.sql.SelectOne(holder, d.prepareQuery(query), args...)
}
func (d *SqlDb) selectNullStr(query string, args ...interface{}) (sql.NullString, error) {
return d.sql.SelectNullStr(d.prepareQuery(query), args...)
}
func (d *SqlDb) selectAll(i interface{}, query string, args ...interface{}) ([]interface{}, error) {
q := d.prepareQuery(query)
return d.sql.Select(i, q, args...)
}
// prepareMigration converts migration SQLite-query to current dialect.
// Supported MySQL and Postgres dialects.
func (d *SqlDb) prepareMigration(query string) string {
@ -289,7 +329,7 @@ func (d *SqlDb) getObjects(projectID int, props db.ObjectProperties, params db.R
return
}
_, err = d.sql.Select(objects, query, args...)
_, err = d.selectAll(objects, query, args...)
return
}
@ -424,7 +464,7 @@ func getSqlForTable(tableName string, p db.RetrieveQueryParams) (string, []inter
}
q := squirrel.Select("*").
From(tableName)
From("`" + tableName + "`")
if p.SortBy != "" {
sortDirection := "ASC"

View File

@ -28,7 +28,8 @@ func (d *SqlDb) UpdateAccessKey(key db.AccessKey) error {
}
func (d *SqlDb) CreateAccessKey(key db.AccessKey) (newKey db.AccessKey, err error) {
res, err := d.exec(
insertID, err := d.insert(
"id",
"insert into access_key (name, type, project_id, `key`, secret) values (?, ?, ?, ?, ?)",
key.Name,
key.Type,
@ -40,13 +41,8 @@ func (d *SqlDb) CreateAccessKey(key db.AccessKey) (newKey db.AccessKey, err erro
return
}
insertID, err := res.LastInsertId()
if err != nil {
return
}
newKey = key
newKey.ID = int(insertID)
newKey.ID = insertID
return
}
@ -84,7 +80,8 @@ func (d *SqlDb) UpdateGlobalAccessKey(key db.AccessKey) error {
}
func (d *SqlDb) CreateGlobalAccessKey(key db.AccessKey) (newKey db.AccessKey, err error) {
res, err := d.exec(
insertID, err := d.insert(
"id",
"insert into access_key (name, type, `key`, secret) values (?, ?, ?, ?)",
key.Name,
key.Type,
@ -95,13 +92,8 @@ func (d *SqlDb) CreateGlobalAccessKey(key db.AccessKey) (newKey db.AccessKey, er
return
}
insertID, err := res.LastInsertId()
if err != nil {
return
}
newKey = key
newKey.ID = int(insertID)
newKey.ID = insertID
return
}

View File

@ -24,7 +24,8 @@ func (d *SqlDb) UpdateEnvironment(env db.Environment) error {
}
func (d *SqlDb) CreateEnvironment(env db.Environment) (newEnv db.Environment, err error) {
res, err := d.exec(
insertID, err := d.insert(
"id",
"insert into project__environment (project_id, name, json, password) values (?, ?, ?, ?)",
env.ProjectID,
env.Name,
@ -35,14 +36,8 @@ func (d *SqlDb) CreateEnvironment(env db.Environment) (newEnv db.Environment, er
return
}
insertID, err := res.LastInsertId()
if err != nil {
return
}
newEnv = env
newEnv.ID = int(insertID)
newEnv.ID = insertID
return
}

View File

@ -31,7 +31,7 @@ func (d *SqlDb) getEventObjectName(evt db.Event) (string, error) {
}
var name sql.NullString
name, err = d.sql.SelectNullStr(query, args...)
name, err = d.selectNullStr(query, args...)
if err != nil {
return "", err
@ -56,7 +56,7 @@ func (d *SqlDb) getEvents(q squirrel.SelectBuilder, params db.RetrieveQueryParam
return
}
_, err = d.sql.Select(&events, query, args...)
_, err = d.selectAll(&events, query, args...)
if err != nil {
return

View File

@ -50,7 +50,8 @@ func (d *SqlDb) UpdateInventory(inventory db.Inventory) error {
}
func (d *SqlDb) CreateInventory(inventory db.Inventory) (newInventory db.Inventory, err error) {
res, err := d.exec(
insertID, err := d.insert(
"id",
"insert into project__inventory set project_id=?, name=?, type=?, key_id=?, ssh_key_id=?, inventory=?",
inventory.ProjectID,
inventory.Name,
@ -63,13 +64,8 @@ func (d *SqlDb) CreateInventory(inventory db.Inventory) (newInventory db.Invento
return
}
insertID, err := res.LastInsertId()
if err != nil {
return
}
newInventory = inventory
newInventory.ID = int(insertID)
newInventory.ID = insertID
return
}

View File

@ -9,18 +9,17 @@ import (
func (d *SqlDb) CreateProject(project db.Project) (newProject db.Project, err error) {
project.Created = time.Now()
res, err := d.exec("insert into project(name, created) values (?, ?)", project.Name, project.Created)
if err != nil {
return
}
insertId, err := d.insert(
"id",
"insert into project(name, created) values (?, ?)",
project.Name, project.Created)
insertId, err := res.LastInsertId()
if err != nil {
return
}
newProject = project
newProject.ID = int(insertId)
newProject.ID = insertId
return
}
@ -36,7 +35,7 @@ func (d *SqlDb) GetProjects(userID int) (projects []db.Project, err error) {
return
}
_, err = d.sql.Select(&projects, query, args...)
_, err = d.selectAll(&projects, query, args...)
return
}
@ -73,7 +72,7 @@ func (d *SqlDb) DeleteProject(projectID int) error {
}
for _, statement := range statements {
_, err = tx.Exec(statement, projectID)
_, err = tx.Exec(d.prepareQuery(statement), projectID)
if err != nil {
err = tx.Rollback()

View File

@ -46,7 +46,7 @@ func (d *SqlDb) GetRepositories(projectID int, params db.RetrieveQueryParams) (r
return
}
_, err = d.sql.Select(&repositories, query, args...)
_, err = d.selectAll(&repositories, query, args...)
return
}
@ -63,7 +63,8 @@ func (d *SqlDb) UpdateRepository(repository db.Repository) error {
}
func (d *SqlDb) CreateRepository(repository db.Repository) (newRepo db.Repository, err error) {
res, err := d.exec(
insertID, err := d.insert(
"id",
"insert into project__repository(project_id, git_url, ssh_key_id, name) values (?, ?, ?, ?)",
repository.ProjectID,
repository.GitURL,
@ -74,13 +75,8 @@ func (d *SqlDb) CreateRepository(repository db.Repository) (newRepo db.Repositor
return
}
insertID, err := res.LastInsertId()
if err != nil {
return
}
newRepo = repository
newRepo.ID = int(insertID)
newRepo.ID = insertID
return
}

View File

@ -56,7 +56,7 @@ func (d *SqlDb) TouchSession(userID int, sessionID int) error {
}
func (d *SqlDb) GetAPITokens(userID int) (tokens []db.APIToken, err error) {
_, err = d.sql.Select(&tokens, d.prepareQuery("select * from user__token where user_id=?"), userID)
_, err = d.selectAll(&tokens, d.prepareQuery("select * from user__token where user_id=?"), userID)
if err == sql.ErrNoRows {
err = db.ErrNotFound

View File

@ -32,10 +32,10 @@ func (d *SqlDb) CreateTaskOutput(output db.TaskOutput) (db.TaskOutput, error) {
}
func (d *SqlDb) getTasks(projectID int, templateID* int, params db.RetrieveQueryParams) (tasks []db.TaskWithTpl, err error) {
q := squirrel.Select("task.*, tpl.playbook as tpl_playbook, user.name as user_name, tpl.alias as tpl_alias").
q := squirrel.Select("task.*, tpl.playbook as tpl_playbook, `user`.name as user_name, tpl.alias as tpl_alias").
From("task").
Join("project__template as tpl on task.template_id=tpl.id").
LeftJoin("user on task.user_id=user.id").
LeftJoin("`user` on task.user_id=`user`.id").
OrderBy("task.created desc, id desc")
if templateID == nil {
@ -50,7 +50,7 @@ func (d *SqlDb) getTasks(projectID int, templateID* int, params db.RetrieveQuery
query, args, _ := q.ToSql()
_, err = d.sql.Select(&tasks, query, args...)
_, err = d.selectAll(&tasks, query, args...)
return
}
@ -110,7 +110,7 @@ func (d *SqlDb) GetTaskOutputs(projectID int, taskID int) (output []db.TaskOutpu
return
}
_, err = d.sql.Select(&output,
_, err = d.selectAll(&output,
"select task_id, task, time, output from task__output where task_id=? order by time asc",
taskID)
return

View File

@ -7,7 +7,9 @@ import (
)
func (d *SqlDb) CreateTemplate(template db.Template) (newTemplate db.Template, err error) {
res, err := d.exec("insert into project__template set ssh_key_id=?, project_id=?, inventory_id=?, repository_id=?, environment_id=?, alias=?, playbook=?, arguments=?, override_args=?",
insertID, err := d.insert(
"id",
"insert into project__template set ssh_key_id=?, project_id=?, inventory_id=?, repository_id=?, environment_id=?, alias=?, playbook=?, arguments=?, override_args=?",
template.SSHKeyID,
template.ProjectID,
template.InventoryID,
@ -22,13 +24,8 @@ func (d *SqlDb) CreateTemplate(template db.Template) (newTemplate db.Template, e
return
}
insertID, err := res.LastInsertId()
if err != nil {
return
}
newTemplate = template
newTemplate.ID = int(insertID)
newTemplate.ID = insertID
return
}
@ -96,7 +93,7 @@ func (d *SqlDb) GetTemplates(projectID int, params db.RetrieveQueryParams) (temp
return
}
_, err = d.sql.Select(&templates, query, args...)
_, err = d.selectAll(&templates, query, args...)
return
}

View File

@ -136,7 +136,7 @@ func (d *SqlDb) GetProjectUser(projectID, userID int) (db.ProjectUser, error) {
func (d *SqlDb) GetProjectUsers(projectID int, params db.RetrieveQueryParams) (users []db.User, err error) {
q := squirrel.Select("u.*").Column("pu.admin").
From("project__user as pu").
LeftJoin("user as u on pu.user_id=u.id").
LeftJoin("`user` as u on pu.user_id=u.id").
Where("pu.project_id=?", projectID)
sortDirection := "ASC"
@ -159,7 +159,7 @@ func (d *SqlDb) GetProjectUsers(projectID int, params db.RetrieveQueryParams) (u
return
}
_, err = d.sql.Select(&users, query, args...)
_, err = d.selectAll(&users, query, args...)
return
}
@ -199,7 +199,7 @@ func (d *SqlDb) GetUsers(params db.RetrieveQueryParams) (users []db.User, err er
return
}
_, err = d.sql.Select(&users, query, args...)
_, err = d.selectAll(&users, query, args...)
return
}