moar refactor

- c.MustGet( -> context.Get(r,
- c.Get( -> context.GetOk(r,
This commit is contained in:
Matej Kramny 2017-02-22 14:21:52 -08:00
parent 2f16f70e98
commit 1ddfcd5b5f
12 changed files with 78 additions and 77 deletions

View File

@ -8,14 +8,14 @@ import (
)
func getEvents(w http.ResponseWriter, r *http.Request) {
user := c.MustGet("user").(*models.User)
user := context.Get(r, "user").(*models.User)
q := squirrel.Select("event.*, p.name as project_name").
From("event").
LeftJoin("project as p on event.project_id=p.id").
OrderBy("created desc")
projectObj, exists := c.Get("project")
projectObj, exists := context.GetOk(r, "project")
if exists == true {
// limit query to project
project := projectObj.(models.Project)

View File

@ -11,7 +11,7 @@ import (
)
func EnvironmentMiddleware(w http.ResponseWriter, r *http.Request) {
project := c.MustGet("project").(models.Project)
project := context.Get(r, "project").(models.Project)
envID, err := util.GetIntParam("environment_id", c)
if err != nil {
return
@ -38,7 +38,7 @@ func EnvironmentMiddleware(w http.ResponseWriter, r *http.Request) {
}
func GetEnvironment(w http.ResponseWriter, r *http.Request) {
project := c.MustGet("project").(models.Project)
project := context.Get(r, "project").(models.Project)
var env []models.Environment
q := squirrel.Select("*").
@ -55,9 +55,9 @@ func GetEnvironment(w http.ResponseWriter, r *http.Request) {
}
func UpdateEnvironment(w http.ResponseWriter, r *http.Request) {
oldEnv := c.MustGet("environment").(models.Environment)
oldEnv := context.Get(r, "environment").(models.Environment)
var env models.Environment
if err := c.Bind(&env); err != nil {
if err := mulekick.Bind(w, r, &env); err != nil {
return
}
@ -69,10 +69,10 @@ func UpdateEnvironment(w http.ResponseWriter, r *http.Request) {
}
func AddEnvironment(w http.ResponseWriter, r *http.Request) {
project := c.MustGet("project").(models.Project)
project := context.Get(r, "project").(models.Project)
var env models.Environment
if err := c.Bind(&env); err != nil {
if err := mulekick.Bind(w, r, &env); err != nil {
return
}
@ -99,7 +99,7 @@ func AddEnvironment(w http.ResponseWriter, r *http.Request) {
}
func RemoveEnvironment(w http.ResponseWriter, r *http.Request) {
env := c.MustGet("environment").(models.Environment)
env := context.Get(r, "environment").(models.Environment)
templatesC, err := database.Mysql.SelectInt("select count(1) from project__template where project_id=? and environment_id=?", env.ProjectID, env.ID)
if err != nil {

View File

@ -11,7 +11,7 @@ import (
)
func InventoryMiddleware(w http.ResponseWriter, r *http.Request) {
project := c.MustGet("project").(models.Project)
project := context.Get(r, "project").(models.Project)
inventoryID, err := util.GetIntParam("inventory_id", c)
if err != nil {
return
@ -38,7 +38,7 @@ func InventoryMiddleware(w http.ResponseWriter, r *http.Request) {
}
func GetInventory(w http.ResponseWriter, r *http.Request) {
project := c.MustGet("project").(models.Project)
project := context.Get(r, "project").(models.Project)
var inv []models.Inventory
query, args, _ := squirrel.Select("*").
@ -54,7 +54,7 @@ func GetInventory(w http.ResponseWriter, r *http.Request) {
}
func AddInventory(w http.ResponseWriter, r *http.Request) {
project := c.MustGet("project").(models.Project)
project := context.Get(r, "project").(models.Project)
var inventory struct {
Name string `json:"name" binding:"required"`
KeyID *int `json:"key_id"`
@ -63,7 +63,7 @@ func AddInventory(w http.ResponseWriter, r *http.Request) {
Inventory string `json:"inventory"`
}
if err := c.Bind(&inventory); err != nil {
if err := mulekick.Bind(w, r, &inventory); err != nil {
return
}
@ -98,7 +98,7 @@ func AddInventory(w http.ResponseWriter, r *http.Request) {
}
func UpdateInventory(w http.ResponseWriter, r *http.Request) {
oldInventory := c.MustGet("inventory").(models.Inventory)
oldInventory := context.Get(r, "inventory").(models.Inventory)
var inventory struct {
Name string `json:"name" binding:"required"`
@ -108,7 +108,7 @@ func UpdateInventory(w http.ResponseWriter, r *http.Request) {
Inventory string `json:"inventory"`
}
if err := c.Bind(&inventory); err != nil {
if err := mulekick.Bind(w, r, &inventory); err != nil {
return
}
@ -139,7 +139,7 @@ func UpdateInventory(w http.ResponseWriter, r *http.Request) {
}
func RemoveInventory(w http.ResponseWriter, r *http.Request) {
inventory := c.MustGet("inventory").(models.Inventory)
inventory := context.Get(r, "inventory").(models.Inventory)
templatesC, err := database.Mysql.SelectInt("select count(1) from project__template where project_id=? and inventory_id=?", inventory.ProjectID, inventory.ID)
if err != nil {

View File

@ -11,7 +11,7 @@ import (
)
func KeyMiddleware(w http.ResponseWriter, r *http.Request) {
project := c.MustGet("project").(models.Project)
project := context.Get(r, "project").(models.Project)
keyID, err := util.GetIntParam("key_id", c)
if err != nil {
return
@ -32,7 +32,7 @@ func KeyMiddleware(w http.ResponseWriter, r *http.Request) {
}
func GetKeys(w http.ResponseWriter, r *http.Request) {
project := c.MustGet("project").(models.Project)
project := context.Get(r, "project").(models.Project)
var keys []models.AccessKey
q := squirrel.Select("id, name, type, project_id, `key`, removed").
@ -52,10 +52,10 @@ func GetKeys(w http.ResponseWriter, r *http.Request) {
}
func AddKey(w http.ResponseWriter, r *http.Request) {
project := c.MustGet("project").(models.Project)
project := context.Get(r, "project").(models.Project)
var key models.AccessKey
if err := c.Bind(&key); err != nil {
if err := mulekick.Bind(w, r, &key); err != nil {
return
}
@ -102,9 +102,9 @@ func AddKey(w http.ResponseWriter, r *http.Request) {
func UpdateKey(w http.ResponseWriter, r *http.Request) {
var key models.AccessKey
oldKey := c.MustGet("accessKey").(models.AccessKey)
oldKey := context.Get(r, "accessKey").(models.AccessKey)
if err := c.Bind(&key); err != nil {
if err := mulekick.Bind(w, r, &key); err != nil {
return
}
@ -152,7 +152,7 @@ func UpdateKey(w http.ResponseWriter, r *http.Request) {
}
func RemoveKey(w http.ResponseWriter, r *http.Request) {
key := c.MustGet("accessKey").(models.AccessKey)
key := context.Get(r, "accessKey").(models.AccessKey)
templatesC, err := database.Mysql.SelectInt("select count(1) from project__template where project_id=? and ssh_key_id=?", *key.ProjectID, key.ID)
if err != nil {

View File

@ -11,7 +11,7 @@ import (
)
func ProjectMiddleware(w http.ResponseWriter, r *http.Request) {
user := c.MustGet("user").(*models.User)
user := context.Get(r, "user").(*models.User)
projectID, err := util.GetIntParam("project_id", c)
if err != nil {
@ -40,12 +40,12 @@ func ProjectMiddleware(w http.ResponseWriter, r *http.Request) {
}
func GetProject(w http.ResponseWriter, r *http.Request) {
c.JSON(200, c.MustGet("project"))
c.JSON(200, context.Get(r, "project"))
}
func MustBeAdmin(w http.ResponseWriter, r *http.Request) {
project := c.MustGet("project").(models.Project)
user := c.MustGet("user").(*models.User)
project := context.Get(r, "project").(models.Project)
user := context.Get(r, "user").(*models.User)
userC, err := database.Mysql.SelectInt("select count(1) from project__user as pu join user as u on pu.user_id=u.id where pu.user_id=? and pu.project_id=? and pu.admin=1", user.ID, project.ID)
if err != nil {
@ -59,12 +59,12 @@ func MustBeAdmin(w http.ResponseWriter, r *http.Request) {
}
func UpdateProject(w http.ResponseWriter, r *http.Request) {
project := c.MustGet("project").(models.Project)
project := context.Get(r, "project").(models.Project)
var body struct {
Name string `json:"name"`
}
if err := c.Bind(&body); err != nil {
if err := mulekick.Bind(w, r, &body); err != nil {
return
}
@ -76,7 +76,7 @@ func UpdateProject(w http.ResponseWriter, r *http.Request) {
}
func DeleteProject(w http.ResponseWriter, r *http.Request) {
project := c.MustGet("project").(models.Project)
project := context.Get(r, "project").(models.Project)
tx, err := database.Mysql.Begin()
if err != nil {

View File

@ -8,7 +8,7 @@ import (
)
func GetProjects(w http.ResponseWriter, r *http.Request) {
user := c.MustGet("user").(*models.User)
user := context.Get(r, "user").(*models.User)
query, args, _ := squirrel.Select("p.*").
From("project as p").
@ -27,9 +27,9 @@ func GetProjects(w http.ResponseWriter, r *http.Request) {
func AddProject(w http.ResponseWriter, r *http.Request) {
var body models.Project
user := c.MustGet("user").(*models.User)
user := context.Get(r, "user").(*models.User)
if err := c.Bind(&body); err != nil {
if err := mulekick.Bind(w, r, &body); err != nil {
return
}

View File

@ -23,7 +23,7 @@ func clearRepositoryCache(repository models.Repository) error {
}
func RepositoryMiddleware(w http.ResponseWriter, r *http.Request) {
project := c.MustGet("project").(models.Project)
project := context.Get(r, "project").(models.Project)
repositoryID, err := util.GetIntParam("repository_id", c)
if err != nil {
return
@ -44,7 +44,7 @@ func RepositoryMiddleware(w http.ResponseWriter, r *http.Request) {
}
func GetRepositories(w http.ResponseWriter, r *http.Request) {
project := c.MustGet("project").(models.Project)
project := context.Get(r, "project").(models.Project)
var repos []models.Repository
query, args, _ := squirrel.Select("*").
@ -61,14 +61,14 @@ func GetRepositories(w http.ResponseWriter, r *http.Request) {
}
func AddRepository(w http.ResponseWriter, r *http.Request) {
project := c.MustGet("project").(models.Project)
project := context.Get(r, "project").(models.Project)
var repository struct {
Name string `json:"name" binding:"required"`
GitUrl string `json:"git_url" binding:"required"`
SshKeyID int `json:"ssh_key_id" binding:"required"`
}
if err := c.Bind(&repository); err != nil {
if err := mulekick.Bind(w, r, &repository); err != nil {
return
}
@ -95,13 +95,13 @@ func AddRepository(w http.ResponseWriter, r *http.Request) {
}
func UpdateRepository(w http.ResponseWriter, r *http.Request) {
oldRepo := c.MustGet("repository").(models.Repository)
oldRepo := context.Get(r, "repository").(models.Repository)
var repository struct {
Name string `json:"name" binding:"required"`
GitUrl string `json:"git_url" binding:"required"`
SshKeyID int `json:"ssh_key_id" binding:"required"`
}
if err := c.Bind(&repository); err != nil {
if err := mulekick.Bind(w, r, &repository); err != nil {
return
}
@ -128,7 +128,7 @@ func UpdateRepository(w http.ResponseWriter, r *http.Request) {
}
func RemoveRepository(w http.ResponseWriter, r *http.Request) {
repository := c.MustGet("repository").(models.Repository)
repository := context.Get(r, "repository").(models.Repository)
templatesC, err := database.Mysql.SelectInt("select count(1) from project__template where project_id=? and repository_id=?", repository.ProjectID, repository.ID)
if err != nil {

View File

@ -12,7 +12,7 @@ import (
)
func TemplatesMiddleware(w http.ResponseWriter, r *http.Request) {
project := c.MustGet("project").(models.Project)
project := context.Get(r, "project").(models.Project)
templateID, err := util.GetIntParam("template_id", c)
if err != nil {
return
@ -33,7 +33,7 @@ func TemplatesMiddleware(w http.ResponseWriter, r *http.Request) {
}
func GetTemplates(w http.ResponseWriter, r *http.Request) {
project := c.MustGet("project").(models.Project)
project := context.Get(r, "project").(models.Project)
var templates []models.Template
q := squirrel.Select("*").
@ -50,10 +50,10 @@ func GetTemplates(w http.ResponseWriter, r *http.Request) {
}
func AddTemplate(w http.ResponseWriter, r *http.Request) {
project := c.MustGet("project").(models.Project)
project := context.Get(r, "project").(models.Project)
var template models.Template
if err := c.Bind(&template); err != nil {
if err := mulekick.Bind(w, r, &template); err != nil {
return
}
@ -84,10 +84,10 @@ func AddTemplate(w http.ResponseWriter, r *http.Request) {
}
func UpdateTemplate(w http.ResponseWriter, r *http.Request) {
oldTemplate := c.MustGet("template").(models.Template)
oldTemplate := context.Get(r, "template").(models.Template)
var template models.Template
if err := c.Bind(&template); err != nil {
if err := mulekick.Bind(w, r, &template); err != nil {
return
}
@ -110,7 +110,7 @@ func UpdateTemplate(w http.ResponseWriter, r *http.Request) {
}
func RemoveTemplate(w http.ResponseWriter, r *http.Request) {
tpl := c.MustGet("template").(models.Template)
tpl := context.Get(r, "template").(models.Template)
if _, err := database.Mysql.Exec("delete from project__template where id=?", tpl.ID); err != nil {
panic(err)

View File

@ -12,7 +12,7 @@ import (
)
func UserMiddleware(w http.ResponseWriter, r *http.Request) {
project := c.MustGet("project").(models.Project)
project := context.Get(r, "project").(models.Project)
userID, err := util.GetIntParam("user_id", c)
if err != nil {
return
@ -33,7 +33,7 @@ func UserMiddleware(w http.ResponseWriter, r *http.Request) {
}
func GetUsers(w http.ResponseWriter, r *http.Request) {
project := c.MustGet("project").(models.Project)
project := context.Get(r, "project").(models.Project)
var users []struct {
models.User
Admin bool `db:"admin" json:"admin"`
@ -53,13 +53,13 @@ func GetUsers(w http.ResponseWriter, r *http.Request) {
}
func AddUser(w http.ResponseWriter, r *http.Request) {
project := c.MustGet("project").(models.Project)
project := context.Get(r, "project").(models.Project)
var user struct {
UserID int `json:"user_id" binding:"required"`
Admin bool `json:"admin"`
}
if err := c.Bind(&user); err != nil {
if err := mulekick.Bind(w, r, &user); err != nil {
return
}
@ -82,8 +82,8 @@ func AddUser(w http.ResponseWriter, r *http.Request) {
}
func RemoveUser(w http.ResponseWriter, r *http.Request) {
project := c.MustGet("project").(models.Project)
user := c.MustGet("projectUser").(models.User)
project := context.Get(r, "project").(models.Project)
user := context.Get(r, "projectUser").(models.User)
if _, err := database.Mysql.Exec("delete from project__user where user_id=? and project_id=?", user.ID, project.ID); err != nil {
panic(err)
@ -104,8 +104,8 @@ func RemoveUser(w http.ResponseWriter, r *http.Request) {
}
func MakeUserAdmin(w http.ResponseWriter, r *http.Request) {
project := c.MustGet("project").(models.Project)
user := c.MustGet("projectUser").(models.User)
project := context.Get(r, "project").(models.Project)
user := context.Get(r, "projectUser").(models.User)
admin := 1
if r.Method == "DELETE" {

View File

@ -12,11 +12,11 @@ import (
)
func AddTask(w http.ResponseWriter, r *http.Request) {
project := c.MustGet("project").(models.Project)
user := c.MustGet("user").(*models.User)
project := context.Get(r, "project").(models.Project)
user := context.Get(r, "user").(*models.User)
var taskObj models.Task
if err := c.Bind(&taskObj); err != nil {
if err := mulekick.Bind(w, r, &taskObj); err != nil {
return
}
@ -48,7 +48,7 @@ func AddTask(w http.ResponseWriter, r *http.Request) {
}
func GetAll(w http.ResponseWriter, r *http.Request) {
project := c.MustGet("project").(models.Project)
project := context.Get(r, "project").(models.Project)
query, args, _ := squirrel.Select("task.*, tpl.playbook as tpl_playbook, user.name as user_name, tpl.alias as tpl_alias").
From("task").
@ -88,7 +88,7 @@ func GetTaskMiddleware(w http.ResponseWriter, r *http.Request) {
}
func GetTaskOutput(w http.ResponseWriter, r *http.Request) {
task := c.MustGet("task").(models.Task)
task := context.Get(r, "task").(models.Task)
var output []models.TaskOutput
if _, err := database.Mysql.Select(&output, "select * from task__output where task_id=? order by time asc", task.ID); err != nil {
@ -99,7 +99,7 @@ func GetTaskOutput(w http.ResponseWriter, r *http.Request) {
}
func RemoveTask(w http.ResponseWriter, r *http.Request) {
task := c.MustGet("task").(models.Task)
task := context.Get(r, "task").(models.Task)
statements := []string{
"delete from task__output where task_id=?",

View File

@ -13,16 +13,16 @@ import (
)
func getUser(w http.ResponseWriter, r *http.Request) {
if u, exists := c.Get("_user"); exists {
if u, exists := context.GetOk(r, "_user"); exists {
c.JSON(200, u)
return
}
c.JSON(200, c.MustGet("user"))
c.JSON(200, context.Get(r, "user"))
}
func getAPITokens(w http.ResponseWriter, r *http.Request) {
user := c.MustGet("user").(*models.User)
user := context.Get(r, "user").(*models.User)
var tokens []models.APIToken
if _, err := database.Mysql.Select(&tokens, "select * from user__token where user_id=?", user.ID); err != nil {
@ -33,7 +33,7 @@ func getAPITokens(w http.ResponseWriter, r *http.Request) {
}
func createAPIToken(w http.ResponseWriter, r *http.Request) {
user := c.MustGet("user").(*models.User)
user := context.Get(r, "user").(*models.User)
tokenID := make([]byte, 32)
if _, err := io.ReadFull(rand.Reader, tokenID); err != nil {
panic(err)
@ -54,7 +54,7 @@ func createAPIToken(w http.ResponseWriter, r *http.Request) {
}
func expireAPIToken(w http.ResponseWriter, r *http.Request) {
user := c.MustGet("user").(*models.User)
user := context.Get(r, "user").(*models.User)
tokenID := c.Param("token_id")
res, err := database.Mysql.Exec("update user__token set expired=1 where id=? and user_id=?", tokenID, user.ID)

View File

@ -2,12 +2,14 @@ package api
import (
"database/sql"
"net/http"
"time"
database "github.com/ansible-semaphore/semaphore/db"
"github.com/ansible-semaphore/semaphore/models"
"github.com/ansible-semaphore/semaphore/util"
"github.com/gin-gonic/gin"
"github.com/castawaylabs/mulekick"
"github.com/gorilla/context"
"golang.org/x/crypto/bcrypt"
)
@ -17,12 +19,12 @@ func getUsers(w http.ResponseWriter, r *http.Request) {
panic(err)
}
c.JSON(200, users)
mulekick.WriteJSON(w, http.StatusOK, users)
}
func addUser(w http.ResponseWriter, r *http.Request) {
var user models.User
if err := c.Bind(&user); err != nil {
if err := mulekick.Bind(w, r, &user); err != nil {
return
}
@ -32,7 +34,7 @@ func addUser(w http.ResponseWriter, r *http.Request) {
panic(err)
}
c.JSON(201, user)
mulekick.WriteJSON(w, http.StatusCreated, user)
}
func getUserMiddleware(w http.ResponseWriter, r *http.Request) {
@ -51,15 +53,14 @@ func getUserMiddleware(w http.ResponseWriter, r *http.Request) {
panic(err)
}
c.Set("_user", user)
c.Next()
context.Set(r, "_user", user)
}
func updateUser(w http.ResponseWriter, r *http.Request) {
oldUser := c.MustGet("_user").(models.User)
oldUser := context.Get(r, "_user").(models.User)
var user models.User
if err := c.Bind(&user); err != nil {
if err := mulekick.Bind(w, r, &user); err != nil {
return
}
@ -71,12 +72,12 @@ func updateUser(w http.ResponseWriter, r *http.Request) {
}
func updateUserPassword(w http.ResponseWriter, r *http.Request) {
user := c.MustGet("_user").(models.User)
user := context.Get(r, "_user").(models.User)
var pwd struct {
Pwd string `json:"password"`
}
if err := c.Bind(&pwd); err != nil {
if err := mulekick.Bind(w, r, &pwd); err != nil {
return
}
@ -89,7 +90,7 @@ func updateUserPassword(w http.ResponseWriter, r *http.Request) {
}
func deleteUser(w http.ResponseWriter, r *http.Request) {
user := c.MustGet("_user").(models.User)
user := context.Get(r, "_user").(models.User)
if _, err := database.Mysql.Exec("delete from project__user where user_id=?", user.ID); err != nil {
panic(err)