feat(backup): use marshal/unmarshal function

This commit is contained in:
Denis Gukov 2024-10-07 14:35:20 +05:00
parent 775de44489
commit 4011f358b0
10 changed files with 82 additions and 65 deletions

View File

@ -87,7 +87,8 @@ func WriteJSON(w http.ResponseWriter, code int, out interface{}) {
w.WriteHeader(code)
if err := json.NewEncoder(w).Encode(out); err != nil {
panic(err)
log.Error(err)
debug.PrintStack()
}
}

View File

@ -1,7 +1,9 @@
package projects
import (
"io"
"net/http"
"strings"
"github.com/ansible-semaphore/semaphore/api/helpers"
"github.com/ansible-semaphore/semaphore/db"
@ -21,27 +23,47 @@ func GetBackup(w http.ResponseWriter, r *http.Request) {
helpers.WriteError(w, err)
return
}
helpers.WriteJSON(w, http.StatusOK, backup)
str, err := backup.Marshal()
if err != nil {
helpers.WriteError(w, err)
return
}
w.Header().Set("content-type", "application/json")
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte(str))
}
func Restore(w http.ResponseWriter, r *http.Request) {
user := context.Get(r, "user").(*db.User)
var backup projectService.BackupFormat
var p *db.Project
var err error
if !helpers.Bind(w, r, &backup) {
helpers.WriteJSON(w, http.StatusBadRequest, backup)
return
}
store := helpers.Store(r)
if err = backup.Verify(); err != nil {
buf := new(strings.Builder)
if _, err := io.Copy(buf, r.Body); err != nil {
log.Error(err)
helpers.WriteError(w, err)
return
}
if p, err = backup.Restore(*user, store); err != nil {
if err := backup.Unmarshal(buf.String()); err != nil {
log.Error(err)
helpers.WriteError(w, err)
return
}
store := helpers.Store(r)
if err := backup.Verify(); err != nil {
log.Error(err)
helpers.WriteError(w, err)
return
}
var p *db.Project
p, err := backup.Restore(*user, store)
if err != nil {
log.Error(err)
helpers.WriteError(w, err)
return

View File

@ -47,6 +47,8 @@ type AccessKey struct {
// UserID is an ID of user which owns the access key.
UserID *int `db:"user_id" json:"-" backup:"-"`
Empty bool `db:"-" json:"empty"`
}
type LoginPassword struct {

View File

@ -78,6 +78,10 @@ func (d *BoltDb) GetTemplates(projectID int, filter db.TemplateFilter, params db
err = d.apply(projectID, db.TaskProps, db.RetrieveQueryParams{}, func(i interface{}) error {
task := i.(db.Task)
if task.ProjectID != projectID {
return nil
}
tpl, ok := templatesMap[task.TemplateID]
if !ok {
return nil

View File

@ -3,7 +3,6 @@ package bolt
import (
"github.com/ansible-semaphore/semaphore/db"
"go.etcd.io/bbolt"
"slices"
)
func (d *BoltDb) GetTemplateVaults(projectID int, templateID int) (vaults []db.TemplateVault, err error) {
@ -13,8 +12,8 @@ func (d *BoltDb) GetTemplateVaults(projectID int, templateID int) (vaults []db.T
if err != nil {
return
}
for _, vault := range vaults {
err = db.FillTemplateVault(d, projectID, &vault)
for i := range vaults {
err = db.FillTemplateVault(d, projectID, &vaults[i])
if err != nil {
return
}
@ -40,39 +39,25 @@ func (d *BoltDb) UpdateTemplateVaults(projectID int, templateID int, vaults []db
var oldVaults []db.TemplateVault
oldVaults, err = d.GetTemplateVaults(projectID, templateID)
var vaultIDs []int
for _, vault := range vaults {
vault.ProjectID = projectID
vault.TemplateID = templateID
if vault.ID == 0 {
// Insert new vaults
var newTpl interface{}
newTpl, err = d.createObject(projectID, db.TemplateVaultProps, vault)
err = d.db.Update(func(tx *bbolt.Tx) error {
for _, vault := range oldVaults {
err = d.deleteObject(projectID, db.TemplateVaultProps, intObjectID(vault.ID), tx)
if err != nil {
return
return err
}
vaultIDs = append(vaultIDs, newTpl.(db.TemplateVault).ID)
} else {
// Update existing vaults
err = d.updateObject(projectID, db.TemplateVaultProps, vault)
vaultIDs = append(vaultIDs, vault.ID)
}
if err != nil {
return
}
}
// Delete missing vaults
for _, vault := range oldVaults {
if !slices.Contains(vaultIDs, vault.ID) {
err = d.db.Update(func(tx *bbolt.Tx) error {
return d.deleteObject(projectID, db.TemplateVaultProps, intObjectID(vault.ID), tx)
})
for _, vault := range vaults {
vault.ProjectID = projectID
vault.TemplateID = templateID
_, err = d.createObjectTx(tx, projectID, db.TemplateVaultProps, vault)
if err != nil {
return
return err
}
}
}
return nil
})
return
}

View File

@ -9,18 +9,15 @@ import (
func (d *SqlDb) GetTemplateVaults(projectID int, templateID int) (vaults []db.TemplateVault, err error) {
vaults = []db.TemplateVault{}
var vlts []db.TemplateVault
_, err = d.selectAll(&vlts, "select * from project__template_vault where project_id=? and template_id=?", projectID, templateID)
_, err = d.selectAll(&vaults, "select * from project__template_vault where project_id=? and template_id=?", projectID, templateID)
if err != nil {
return
}
for _, vault := range vlts {
vault := vault
err = db.FillTemplateVault(d, projectID, &vault)
for i := range vaults {
err = db.FillTemplateVault(d, projectID, &vaults[i])
if err != nil {
return
}
vaults = append(vaults, vault)
}
return
}

View File

@ -39,11 +39,12 @@ func marshalValue(v reflect.Value) (interface{}, error) {
}
tag := fieldType.Tag.Get("backup")
// Check if the field should be backed up
if tag == "-" {
continue // Skip fields with backup:"-"
} else if tag == "" {
}
// Check if the field should be backed up
if tag == "" {
// Get the field name from the "db" tag
tag = fieldType.Tag.Get("db")
if tag == "" || tag == "-" {
@ -57,6 +58,10 @@ func marshalValue(v reflect.Value) (interface{}, error) {
return nil, err
}
if value == nil {
continue
}
result[tag] = value
}
return result, nil
@ -251,18 +256,19 @@ func unmarshalStructWithBackupTags(data map[string]interface{}, v reflect.Value)
}
// Skip fields with backup:"-"
if backupTag := fieldType.Tag.Get("backup"); backupTag == "-" {
backupTag := fieldType.Tag.Get("backup")
if backupTag == "-" {
continue
}
// Determine the JSON key to use
var jsonKey string
backupTag := fieldType.Tag.Get("backup")
if backupTag != "" {
jsonKey = backupTag
} else {
dbTag := fieldType.Tag.Get("db")
if dbTag != "" {
if dbTag != "" && dbTag != "-" {
jsonKey = dbTag
} else {
continue // Skip if no backup or db tag

View File

@ -341,6 +341,14 @@ func (backup *BackupFormat) Restore(user db.User, store db.Store) (*db.Project,
return nil, err
}
if _, err = store.CreateProjectUser(db.ProjectUser{
ProjectID: newProject.ID,
UserID: user.ID,
Role: db.ProjectOwner,
}); err != nil {
return nil, err
}
b.meta = newProject
for i, o := range backup.Environments {
@ -383,6 +391,7 @@ func (backup *BackupFormat) Restore(user db.User, store db.Store) (*db.Project,
return nil, fmt.Errorf("error at templates[%d]: %s", i, err.Error())
}
}
for _, i := range deployTemplates {
o := backup.Templates[i]
if err := o.Restore(store, &b); err != nil {
@ -390,13 +399,5 @@ func (backup *BackupFormat) Restore(user db.User, store db.Store) (*db.Project,
}
}
if _, err = store.CreateProjectUser(db.ProjectUser{
ProjectID: newProject.ID,
UserID: user.ID,
Role: db.ProjectOwner,
}); err != nil {
return nil, err
}
return &newProject, nil
}

View File

@ -30,9 +30,8 @@ type LocalJob struct {
// Internal field
Process *os.Process
sshKeyInstallation db.AccessKeyInstallation
becomeKeyInstallation db.AccessKeyInstallation
sshKeyInstallation db.AccessKeyInstallation
becomeKeyInstallation db.AccessKeyInstallation
vaultFileInstallations map[string]db.AccessKeyInstallation
}

View File

@ -167,7 +167,7 @@ func (t *TaskRunner) run() {
t.SetStatus(task_logger.TaskSuccessStatus)
}
templates, err := t.pool.store.GetTemplates(t.Task.ProjectID, db.TemplateFilter{
tpls, err := t.pool.store.GetTemplates(t.Task.ProjectID, db.TemplateFilter{
BuildTemplateID: &t.Task.TemplateID,
AutorunOnly: true,
}, db.RetrieveQueryParams{})
@ -176,7 +176,7 @@ func (t *TaskRunner) run() {
return
}
for _, tpl := range templates {
for _, tpl := range tpls {
_, err = t.pool.AddTask(db.Task{
TemplateID: tpl.ID,
ProjectID: tpl.ProjectID,