diff --git a/api/auth.go b/api/auth.go index a7312187..a69e5818 100644 --- a/api/auth.go +++ b/api/auth.go @@ -114,20 +114,10 @@ func authentication(next http.Handler) http.Handler { func authenticationWithStore(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { store := helpers.Store(r) - var url = r.URL.String() - if !store.KeepConnection() { - err := store.Connect(url) - if err != nil { - panic(err) - } - } - - authenticationHandler(w, r) - - if !store.KeepConnection() { - _ = store.Close(url) - } + db.StoreSession(store, r.URL.String(), func() { + authenticationHandler(w, r) + }) next.ServeHTTP(w, r) }) diff --git a/api/router.go b/api/router.go index 7a7f855d..4bf56f12 100644 --- a/api/router.go +++ b/api/router.go @@ -3,6 +3,7 @@ package api import ( "fmt" "github.com/ansible-semaphore/semaphore/api/helpers" + "github.com/ansible-semaphore/semaphore/db" "net/http" "os" "strings" @@ -21,18 +22,9 @@ func StoreMiddleware(next http.Handler) http.Handler { store := helpers.Store(r) var url = r.URL.String() - if !store.KeepConnection() { - err := store.Connect(url) - if err != nil { - panic(err) - } - } - - next.ServeHTTP(w, r) - - if !store.KeepConnection() { - _ = store.Close(url) - } + db.StoreSession(store, url, func() { + next.ServeHTTP(w, r) + }) }) } diff --git a/api/sockets/handler.go b/api/sockets/handler.go index bdb4319c..f2a59814 100644 --- a/api/sockets/handler.go +++ b/api/sockets/handler.go @@ -104,8 +104,12 @@ func (c *connection) writePump() { // Handler is used by the router to handle the /ws endpoint func Handler(w http.ResponseWriter, r *http.Request) { + usr := context.Get(r, "user") + if usr == nil { + return + } - user := context.Get(r, "user").(*db.User) + user := usr.(*db.User) ws, err := upgrader.Upgrade(w, r, nil) if err != nil { panic(err) diff --git a/cli/cmd/root.go b/cli/cmd/root.go index b400421f..4bd24773 100644 --- a/cli/cmd/root.go +++ b/cli/cmd/root.go @@ -13,7 +13,6 @@ import ( "github.com/gorilla/context" "github.com/gorilla/handlers" "github.com/spf13/cobra" - "go.etcd.io/bbolt" "net/http" "os" ) @@ -88,7 +87,7 @@ func runService() { fmt.Println("Server is running") - if store.KeepConnection() { + if store.PermanentConnection() { defer store.Close("root") } else { store.Close("root") @@ -106,15 +105,17 @@ func createStore(token string) db.Store { store := factory.CreateStore() - if err := store.Connect(token); err != nil { - switch err { - case bbolt.ErrTimeout: - fmt.Println("\n BoltDB supports only one connection at a time. You should stop Semaphore to use CLI.") - default: - fmt.Println("\n Have you run `semaphore setup`?") - } - os.Exit(1) - } + store.Connect(token) + + //if err := store.Connect(token); err != nil { + // switch err { + // case bbolt.ErrTimeout: + // fmt.Println("\n BoltDB supports only one connection at a time. You should stop Semaphore to use CLI.") + // default: + // fmt.Println("\n Have you run `semaphore setup`?") + // } + // os.Exit(1) + //} err := db.Migrate(store) diff --git a/cli/cmd/setup.go b/cli/cmd/setup.go index c2a8724f..e87a1269 100644 --- a/cli/cmd/setup.go +++ b/cli/cmd/setup.go @@ -38,10 +38,7 @@ func doSetup() int { store := factory.CreateStore() defer store.Close("setup") - if err := store.Connect("setup"); err != nil { - fmt.Printf("Cannot connect to database!\n %v\n", err.Error()) - os.Exit(1) - } + store.Connect("setup") fmt.Println("Running db Migrations..") if err := db.Migrate(store); err != nil { diff --git a/database_test.boltdb b/database_test.boltdb deleted file mode 100644 index b5ca2227..00000000 Binary files a/database_test.boltdb and /dev/null differ diff --git a/db/Store.go b/db/Store.go index 19c5a8b3..4c9d624d 100644 --- a/db/Store.go +++ b/db/Store.go @@ -78,14 +78,14 @@ func (e *ValidationError) Error() string { type Store interface { // Connect connects to the database. - // token parameter used if KeepConnection returns false. - Connect(token string) error - Close(token string) error + // token parameter used if PermanentConnection returns false. + Connect(token string) + Close(token string) - // KeepConnection returns true if connection should be kept from start to finish of the app. + // PermanentConnection returns true if connection should be kept from start to finish of the app. // This mode is suitable for MySQL and Postgres but not for BoltDB. // For BoltDB we should reconnect for each request because BoltDB support only one connection at time. - KeepConnection() bool + PermanentConnection() bool // IsInitialized indicates is database already initialized, or it is empty. // The method is useful for creating required entities in database during first run. @@ -326,3 +326,15 @@ func (p ObjectProps) GetReferringFieldsFrom(t reflect.Type) (fields []string, er return } + +func StoreSession(store Store, token string, callback func()) { + if !store.PermanentConnection() { + store.Connect(token) + } + + callback() + + if !store.PermanentConnection() { + store.Close(token) + } +} diff --git a/db/bolt/BoltDb.go b/db/bolt/BoltDb.go index 9ca27168..12404ced 100644 --- a/db/bolt/BoltDb.go +++ b/db/bolt/BoltDb.go @@ -12,6 +12,7 @@ import ( "sort" "strconv" "strings" + "sync" "time" ) @@ -36,6 +37,7 @@ type BoltDb struct { Filename string db *bbolt.DB connections map[string]bool + mu sync.Mutex } type objectID interface { @@ -71,31 +73,28 @@ func (d *BoltDb) Migrate() error { return nil } -func (d *BoltDb) Connect(token string) error { +func (d *BoltDb) Connect(token string) { + d.mu.Lock() + defer d.mu.Unlock() + if d.connections == nil { d.connections = make(map[string]bool) } - fmt.Println("CONN " + token) - if _, exists := d.connections[token]; exists { - return fmt.Errorf("Connection " + token + " already exists") - } - - for k := range d.connections { - fmt.Println("- EXIST " + k) + panic(fmt.Errorf("Connection " + token + " already exists")) } if len(d.connections) > 0 { d.connections[token] = true - return nil + return } var filename string if d.Filename == "" { config, err := util.Config.GetDBConfig() if err != nil { - return err + panic(err) } filename = config.Hostname } else { @@ -108,25 +107,25 @@ func (d *BoltDb) Connect(token string) error { }) if err != nil { - return err + panic(err) } d.connections[token] = true - return nil } -func (d *BoltDb) Close(token string) error { - fmt.Println("CLOSE " + token) +func (d *BoltDb) Close(token string) { + d.mu.Lock() + defer d.mu.Unlock() _, exists := d.connections[token] if !exists { - panic(fmt.Errorf("can not close of connection closed")) + panic(fmt.Errorf("can not close closed connection " + token)) } if len(d.connections) > 1 { delete(d.connections, token) - return nil + return } err := d.db.Close() @@ -136,15 +135,9 @@ func (d *BoltDb) Close(token string) error { d.db = nil delete(d.connections, token) - - for k := range d.connections { - fmt.Println("- EXIST " + k) - } - - return nil } -func (d *BoltDb) KeepConnection() bool { +func (d *BoltDb) PermanentConnection() bool { return false } @@ -697,15 +690,12 @@ func (d *BoltDb) isObjectInUse(bucketID int, objProps db.ObjectProps, objID obje return } -func CreateTestStore() BoltDb { +func CreateTestStore() *BoltDb { r := rand.New(rand.NewSource(time.Now().UTC().UnixNano())) fn := "/tmp/test_semaphore_db_" + strconv.Itoa(r.Int()) store := BoltDb{ Filename: fn, } - err := store.Connect("test") - if err != nil { - panic(err) - } - return store + store.Connect("test") + return &store } diff --git a/db/bolt/Task_test.go b/db/bolt/Task_test.go index 8f62520d..be8b4e64 100644 --- a/db/bolt/Task_test.go +++ b/db/bolt/Task_test.go @@ -69,7 +69,7 @@ func TestTask_GetVersion(t *testing.T) { t.Fatal(err) } - version := deployTask.GetIncomingVersion(&store) + version := deployTask.GetIncomingVersion(store) if version == nil { t.Fatal() return @@ -79,7 +79,7 @@ func TestTask_GetVersion(t *testing.T) { return } - version = deploy2Task.GetIncomingVersion(&store) + version = deploy2Task.GetIncomingVersion(store) if version == nil { t.Fatal() return diff --git a/db/sql/SqlDb.go b/db/sql/SqlDb.go index cdf6c87a..90a242e7 100644 --- a/db/sql/SqlDb.go +++ b/db/sql/SqlDb.go @@ -247,38 +247,41 @@ func (d *SqlDb) deleteObject(projectID int, props db.ObjectProps, objectID int) objectID)) } -func (d *SqlDb) Close(token string) error { - return d.sql.Db.Close() +func (d *SqlDb) Close(token string) { + err := d.sql.Db.Close() + if err != nil { + panic(err) + } } -func (d *SqlDb) KeepConnection() bool { +func (d *SqlDb) PermanentConnection() bool { return true } -func (d *SqlDb) Connect(token string) error { +func (d *SqlDb) Connect(token string) { sqlDb, err := connect() if err != nil { - return err + panic(err) } if err := sqlDb.Ping(); err != nil { if err = createDb(); err != nil { - return err + panic(err) } sqlDb, err = connect() if err != nil { - return err + panic(err) } if err = sqlDb.Ping(); err != nil { - return err + panic(err) } } cfg, err := util.Config.GetDBConfig() if err != nil { - return err + panic(err) } var dialect gorp.Dialect @@ -303,8 +306,6 @@ func (d *SqlDb) Connect(token string) error { d.sql.AddTableWithName(db.Template{}, "project__template").SetKeys(true, "id") d.sql.AddTableWithName(db.User{}, "user").SetKeys(true, "id") d.sql.AddTableWithName(db.Session{}, "session").SetKeys(true, "id") - - return nil } func getSqlForTable(tableName string, p db.RetrieveQueryParams) (string, []interface{}, error) { diff --git a/services/schedules/pool.go b/services/schedules/pool.go index 775a15f3..4654bf1e 100644 --- a/services/schedules/pool.go +++ b/services/schedules/pool.go @@ -50,10 +50,11 @@ func (r ScheduleRunner) tryUpdateScheduleCommitHash(schedule db.Schedule) (updat } func (r ScheduleRunner) Run() { - if !r.pool.store.KeepConnection() { + if !r.pool.store.PermanentConnection() { r.pool.store.Connect("schedule") defer r.pool.store.Close("schedule") } + schedule, err := r.pool.store.GetSchedule(r.projectID, r.scheduleID) if err != nil { log.Error(err) diff --git a/services/tasks/pool.go b/services/tasks/pool.go index b1b2405e..9672507a 100644 --- a/services/tasks/pool.go +++ b/services/tasks/pool.go @@ -106,39 +106,31 @@ func (p *TaskPool) Run() { for { select { case record := <-p.logger: // new log message which should be put to database - if !record.task.pool.store.KeepConnection() { - err := record.task.pool.store.Connect("task " + strconv.Itoa(record.task.task.ID) + " output") - + db.StoreSession(p.store, "logger", func() { + _, err := p.store.CreateTaskOutput(db.TaskOutput{ + TaskID: record.task.task.ID, + Output: record.output, + Time: record.time, + }) if err != nil { log.Error(err) } - } - - _, err := record.task.pool.store.CreateTaskOutput(db.TaskOutput{ - TaskID: record.task.task.ID, - Output: record.output, - Time: record.time, }) - if !record.task.pool.store.KeepConnection() { - _ = record.task.pool.store.Close("task " + strconv.Itoa(record.task.task.ID) + " output") - } - - if err != nil { - log.Error(err) - } - case task := <-p.register: // new task created by API or schedule - p.queue = append(p.queue, task) - log.Debug(task) - msg := "Task " + strconv.Itoa(task.task.ID) + " added to queue" - task.Log(msg) - log.Info(msg) - task.updateStatus() + + db.StoreSession(p.store, "new task", func() { + p.queue = append(p.queue, task) + log.Debug(task) + msg := "Task " + strconv.Itoa(task.task.ID) + " added to queue" + task.Log(msg) + log.Info(msg) + task.updateStatus() + }) case <-ticker.C: // timer 5 seconds if len(p.queue) == 0 { - continue + break } //get TaskRunner from top of queue @@ -147,19 +139,22 @@ func (p *TaskPool) Run() { //delete failed TaskRunner from queue p.queue = p.queue[1:] log.Info("Task " + strconv.Itoa(t.task.ID) + " removed from queue") - continue + break } + if p.blocks(t) { //move blocked TaskRunner to end of queue p.queue = append(p.queue[1:], t) - continue + break } + log.Info("Set resource locker with TaskRunner " + strconv.Itoa(t.task.ID)) p.resourceLocker <- &resourceLock{lock: true, holder: t} if !t.prepared { go t.prepareRun() - continue + break } + go t.run() p.queue = p.queue[1:] log.Info("Task " + strconv.Itoa(t.task.ID) + " removed from queue") diff --git a/services/tasks/runner.go b/services/tasks/runner.go index 5055720e..e7690aca 100644 --- a/services/tasks/runner.go +++ b/services/tasks/runner.go @@ -154,12 +154,9 @@ func (t *TaskRunner) createTaskEvent() { func (t *TaskRunner) prepareRun() { t.prepared = false - if !t.pool.store.KeepConnection() { - err := t.pool.store.Connect("task " + strconv.Itoa(t.task.ID)) - - if err != nil { - t.panicOnError(err, "Fatal error inserting an event") - } + if !t.pool.store.PermanentConnection() { + t.pool.store.Connect("task " + strconv.Itoa(t.task.ID)) + defer t.pool.store.Close("task " + strconv.Itoa(t.task.ID)) } defer func() { @@ -168,10 +165,6 @@ func (t *TaskRunner) prepareRun() { t.pool.resourceLocker <- &resourceLock{lock: false, holder: t} t.createTaskEvent() - - if !t.pool.store.KeepConnection() { - t.pool.store.Close("task " + strconv.Itoa(t.task.ID)) - } }() t.Log("Preparing: " + strconv.Itoa(t.task.ID)) diff --git a/services/tasks/runner_test.go b/services/tasks/runner_test.go index 0ea55f94..6d02139f 100644 --- a/services/tasks/runner_test.go +++ b/services/tasks/runner_test.go @@ -76,10 +76,8 @@ func TestPopulateDetails(t *testing.T) { store := bolt.BoltDb{ Filename: fn, } - err := store.Connect("") - if err != nil { - t.Fatal(err) - } + + store.Connect("") proj, err := store.CreateProject(db.Project{}) if err != nil {