diff --git a/api/projects/tasks.go b/api/projects/tasks.go index d4c448c7..4e331c6f 100644 --- a/api/projects/tasks.go +++ b/api/projects/tasks.go @@ -10,6 +10,7 @@ import ( log "github.com/sirupsen/logrus" "net/http" "strconv" + "time" ) // AddTask inserts a task into the database and returns a header or returns error @@ -222,7 +223,42 @@ func RemoveTask(w http.ResponseWriter, r *http.Request) { func GetTaskStats(w http.ResponseWriter, r *http.Request) { project := context.Get(r, "project").(db.Project) - stats, err := helpers.Store(r).GetTaskStats(project.ID, nil, db.TaskStatUnitDay, db.TaskFilter{}) + var tplID *int + if tpl := context.Get(r, "template"); tpl != nil { + id := tpl.(db.Template).ID + tplID = &id + } + + filter := db.TaskFilter{} + + if start := r.URL.Query().Get("start"); start != "" { + d, err := time.Parse("2006-01-02", start) + if err != nil { + helpers.WriteErrorStatus(w, "Invalid start date", http.StatusBadRequest) + return + } + filter.Start = &d + } + + if end := r.URL.Query().Get("end"); end != "" { + d, err := time.Parse("2006-01-02", end) + if err != nil { + helpers.WriteErrorStatus(w, "Invalid end date", http.StatusBadRequest) + return + } + filter.End = &d + } + + if userId := r.URL.Query().Get("user_id"); userId != "" { + u, err := strconv.Atoi(userId) + if err != nil { + helpers.WriteErrorStatus(w, "Invalid user_id", http.StatusBadRequest) + return + } + filter.UserID = &u + } + + stats, err := helpers.Store(r).GetTaskStats(project.ID, tplID, db.TaskStatUnitDay, filter) if err != nil { util.LogErrorWithFields(err, log.Fields{"error": "Bad request. Cannot get task stats from database"}) w.WriteHeader(http.StatusBadRequest) diff --git a/db/Store.go b/db/Store.go index db3aedff..16e2393c 100644 --- a/db/Store.go +++ b/db/Store.go @@ -99,8 +99,8 @@ const TaskStatUnitWeek TaskStatUnit = "week" const TaskStatUnitMonth TaskStatUnit = "month" type TaskFilter struct { - From *time.Time `json:"from"` - To *time.Time `json:"to"` + Start *time.Time `json:"start"` + End *time.Time `json:"end"` UserID *int `json:"user_id"` } diff --git a/db/sql/SqlDb.go b/db/sql/SqlDb.go index 7a8ab54c..09b043a4 100644 --- a/db/sql/SqlDb.go +++ b/db/sql/SqlDb.go @@ -763,6 +763,8 @@ func (d *SqlDb) GetObjectReferences(objectProps db.ObjectProps, referringObjectP func (d *SqlDb) GetTaskStats(projectID int, templateID *int, unit db.TaskStatUnit, filter db.TaskFilter) (stats []db.TaskStat, err error) { + stats = make([]db.TaskStat, 0) + if unit != db.TaskStatUnitDay { err = fmt.Errorf("only day unit is supported") return @@ -784,8 +786,12 @@ func (d *SqlDb) GetTaskStats(projectID int, templateID *int, unit db.TaskStatUni q = q.Where("template_id=?", *templateID) } - if filter.UserID != nil { - q = q.Where("user_id=?", *filter.UserID) + if filter.Start != nil { + q = q.Where("start>=?", *filter.Start) + } + + if filter.End != nil { + q = q.Where("end