feat: 支持账单编辑(PATCH /api/bills/:id)

This commit is contained in:
clz
2026-01-18 20:17:19 +08:00
parent 339b8afe98
commit f5afb0c135
8 changed files with 326 additions and 37 deletions

View File

@@ -0,0 +1,145 @@
package handler
import (
"fmt"
"net/http"
"strings"
"time"
"github.com/gin-gonic/gin"
"billai-server/model"
"billai-server/repository"
)
// UpdateBillRequest 账单更新请求(字段均为可选)
type UpdateBillRequest struct {
Time *string `json:"time,omitempty"`
Category *string `json:"category,omitempty"`
Merchant *string `json:"merchant,omitempty"`
Description *string `json:"description,omitempty"`
IncomeExpense *string `json:"income_expense,omitempty"`
Amount *float64 `json:"amount,omitempty"`
PayMethod *string `json:"pay_method,omitempty"`
Status *string `json:"status,omitempty"`
Remark *string `json:"remark,omitempty"`
}
type UpdateBillResponse struct {
Result bool `json:"result"`
Message string `json:"message,omitempty"`
Data *model.CleanedBill `json:"data,omitempty"`
}
func parseBillTime(s string) (time.Time, error) {
s = strings.TrimSpace(s)
formats := []string{
"2006-01-02 15:04:05",
"2006-01-02T15:04:05Z07:00",
"2006-01-02T15:04:05Z",
"2006-01-02",
}
for _, f := range formats {
if t, err := time.ParseInLocation(f, s, time.Local); err == nil {
return t, nil
}
}
return time.Time{}, fmt.Errorf("unsupported time format")
}
// UpdateBill PATCH /api/bills/:id 更新清洗后的账单记录
func UpdateBill(c *gin.Context) {
id := strings.TrimSpace(c.Param("id"))
if id == "" {
c.JSON(http.StatusBadRequest, UpdateBillResponse{Result: false, Message: "缺少账单 ID"})
return
}
var req UpdateBillRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, UpdateBillResponse{Result: false, Message: "参数解析失败: " + err.Error()})
return
}
updates := map[string]interface{}{}
if req.Time != nil {
t, err := parseBillTime(*req.Time)
if err != nil {
c.JSON(http.StatusBadRequest, UpdateBillResponse{Result: false, Message: "时间格式错误"})
return
}
updates["time"] = t
}
if req.Category != nil {
v := strings.TrimSpace(*req.Category)
if v == "" {
c.JSON(http.StatusBadRequest, UpdateBillResponse{Result: false, Message: "分类不能为空"})
return
}
updates["category"] = v
}
if req.Merchant != nil {
v := strings.TrimSpace(*req.Merchant)
if v == "" {
c.JSON(http.StatusBadRequest, UpdateBillResponse{Result: false, Message: "商家不能为空"})
return
}
updates["merchant"] = v
}
if req.Description != nil {
updates["description"] = strings.TrimSpace(*req.Description)
}
if req.IncomeExpense != nil {
v := strings.TrimSpace(*req.IncomeExpense)
if v != "" && v != "收入" && v != "支出" {
c.JSON(http.StatusBadRequest, UpdateBillResponse{Result: false, Message: "income_expense 只能是 收入 或 支出"})
return
}
updates["income_expense"] = v
}
if req.Amount != nil {
updates["amount"] = *req.Amount
}
if req.PayMethod != nil {
updates["pay_method"] = strings.TrimSpace(*req.PayMethod)
}
if req.Status != nil {
updates["status"] = strings.TrimSpace(*req.Status)
}
if req.Remark != nil {
updates["remark"] = strings.TrimSpace(*req.Remark)
}
if len(updates) == 0 {
c.JSON(http.StatusBadRequest, UpdateBillResponse{Result: false, Message: "没有可更新的字段"})
return
}
updates["updated_at"] = time.Now()
repo := repository.GetRepository()
if repo == nil {
c.JSON(http.StatusInternalServerError, UpdateBillResponse{Result: false, Message: "数据库未连接"})
return
}
updated, err := repo.UpdateCleanedBillByID(id, updates)
if err != nil {
if err == repository.ErrNotFound {
c.JSON(http.StatusNotFound, UpdateBillResponse{Result: false, Message: "账单不存在"})
return
}
c.JSON(http.StatusInternalServerError, UpdateBillResponse{Result: false, Message: "更新失败: " + err.Error()})
return
}
c.JSON(http.StatusOK, UpdateBillResponse{Result: true, Message: "更新成功", Data: updated})
}

View File

@@ -0,0 +1,6 @@
package repository
import "errors"
// ErrNotFound 表示目标记录不存在
var ErrNotFound = errors.New("not found")

View File

@@ -7,6 +7,7 @@ import (
"time"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/bson/primitive"
"go.mongodb.org/mongo-driver/mongo"
"go.mongodb.org/mongo-driver/mongo/options"
@@ -383,6 +384,40 @@ func (r *Repository) GetMonthlyStats() ([]model.MonthlyStat, error) {
return stats, nil
}
// UpdateCleanedBillByID 按 ID 更新清洗后的账单,并返回更新后的记录
func (r *Repository) UpdateCleanedBillByID(id string, updates map[string]interface{}) (*model.CleanedBill, error) {
if r.cleanedCollection == nil {
return nil, fmt.Errorf("cleaned collection not initialized")
}
oid, err := primitive.ObjectIDFromHex(id)
if err != nil {
return nil, fmt.Errorf("invalid id: %w", err)
}
if len(updates) == 0 {
return nil, fmt.Errorf("no updates")
}
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
filter := bson.M{"_id": oid}
update := bson.M{"$set": updates}
opts := options.FindOneAndUpdate().SetReturnDocument(options.After)
var updated model.CleanedBill
err = r.cleanedCollection.FindOneAndUpdate(ctx, filter, update, opts).Decode(&updated)
if err != nil {
if err == mongo.ErrNoDocuments {
return nil, repository.ErrNotFound
}
return nil, fmt.Errorf("update bill failed: %w", err)
}
return &updated, nil
}
// GetClient 获取 MongoDB 客户端(用于兼容旧代码)
func (r *Repository) GetClient() *mongo.Client {
return r.client

View File

@@ -43,6 +43,9 @@ type BillRepository interface {
// GetBillsNeedReview 获取需要复核的账单
GetBillsNeedReview() ([]model.CleanedBill, error)
// UpdateCleanedBillByID 按 ID 更新清洗后的账单,并返回更新后的记录
UpdateCleanedBillByID(id string, updates map[string]interface{}) (*model.CleanedBill, error)
// CountRawByField 按字段统计原始数据数量
CountRawByField(fieldName, value string) (int64, error)
}

View File

@@ -59,6 +59,9 @@ func setupAPIRoutes(r *gin.Engine) {
// 账单查询
authed.GET("/bills", handler.ListBills)
// 编辑账单
authed.PATCH("/bills/:id", handler.UpdateBill)
// 手动创建账单
authed.POST("/bills/manual", handler.CreateManualBills)