diff --git a/server/handler/update_bill.go b/server/handler/update_bill.go new file mode 100644 index 0000000..e78499b --- /dev/null +++ b/server/handler/update_bill.go @@ -0,0 +1,145 @@ +package handler + +import ( + "fmt" + "net/http" + "strings" + "time" + + "github.com/gin-gonic/gin" + + "billai-server/model" + "billai-server/repository" +) + +// UpdateBillRequest 账单更新请求(字段均为可选) +type UpdateBillRequest struct { + Time *string `json:"time,omitempty"` + Category *string `json:"category,omitempty"` + Merchant *string `json:"merchant,omitempty"` + Description *string `json:"description,omitempty"` + IncomeExpense *string `json:"income_expense,omitempty"` + Amount *float64 `json:"amount,omitempty"` + PayMethod *string `json:"pay_method,omitempty"` + Status *string `json:"status,omitempty"` + Remark *string `json:"remark,omitempty"` +} + +type UpdateBillResponse struct { + Result bool `json:"result"` + Message string `json:"message,omitempty"` + Data *model.CleanedBill `json:"data,omitempty"` +} + +func parseBillTime(s string) (time.Time, error) { + s = strings.TrimSpace(s) + formats := []string{ + "2006-01-02 15:04:05", + "2006-01-02T15:04:05Z07:00", + "2006-01-02T15:04:05Z", + "2006-01-02", + } + for _, f := range formats { + if t, err := time.ParseInLocation(f, s, time.Local); err == nil { + return t, nil + } + } + return time.Time{}, fmt.Errorf("unsupported time format") +} + +// UpdateBill PATCH /api/bills/:id 更新清洗后的账单记录 +func UpdateBill(c *gin.Context) { + id := strings.TrimSpace(c.Param("id")) + if id == "" { + c.JSON(http.StatusBadRequest, UpdateBillResponse{Result: false, Message: "缺少账单 ID"}) + return + } + + var req UpdateBillRequest + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, UpdateBillResponse{Result: false, Message: "参数解析失败: " + err.Error()}) + return + } + + updates := map[string]interface{}{} + + if req.Time != nil { + t, err := parseBillTime(*req.Time) + if err != nil { + c.JSON(http.StatusBadRequest, UpdateBillResponse{Result: false, Message: "时间格式错误"}) + return + } + updates["time"] = t + } + + if req.Category != nil { + v := strings.TrimSpace(*req.Category) + if v == "" { + c.JSON(http.StatusBadRequest, UpdateBillResponse{Result: false, Message: "分类不能为空"}) + return + } + updates["category"] = v + } + + if req.Merchant != nil { + v := strings.TrimSpace(*req.Merchant) + if v == "" { + c.JSON(http.StatusBadRequest, UpdateBillResponse{Result: false, Message: "商家不能为空"}) + return + } + updates["merchant"] = v + } + + if req.Description != nil { + updates["description"] = strings.TrimSpace(*req.Description) + } + + if req.IncomeExpense != nil { + v := strings.TrimSpace(*req.IncomeExpense) + if v != "" && v != "收入" && v != "支出" { + c.JSON(http.StatusBadRequest, UpdateBillResponse{Result: false, Message: "income_expense 只能是 收入 或 支出"}) + return + } + updates["income_expense"] = v + } + + if req.Amount != nil { + updates["amount"] = *req.Amount + } + + if req.PayMethod != nil { + updates["pay_method"] = strings.TrimSpace(*req.PayMethod) + } + + if req.Status != nil { + updates["status"] = strings.TrimSpace(*req.Status) + } + + if req.Remark != nil { + updates["remark"] = strings.TrimSpace(*req.Remark) + } + + if len(updates) == 0 { + c.JSON(http.StatusBadRequest, UpdateBillResponse{Result: false, Message: "没有可更新的字段"}) + return + } + updates["updated_at"] = time.Now() + + repo := repository.GetRepository() + if repo == nil { + c.JSON(http.StatusInternalServerError, UpdateBillResponse{Result: false, Message: "数据库未连接"}) + return + } + + updated, err := repo.UpdateCleanedBillByID(id, updates) + if err != nil { + if err == repository.ErrNotFound { + c.JSON(http.StatusNotFound, UpdateBillResponse{Result: false, Message: "账单不存在"}) + return + } + c.JSON(http.StatusInternalServerError, UpdateBillResponse{Result: false, Message: "更新失败: " + err.Error()}) + return + } + + c.JSON(http.StatusOK, UpdateBillResponse{Result: true, Message: "更新成功", Data: updated}) +} diff --git a/server/repository/errors.go b/server/repository/errors.go new file mode 100644 index 0000000..26e874a --- /dev/null +++ b/server/repository/errors.go @@ -0,0 +1,6 @@ +package repository + +import "errors" + +// ErrNotFound 表示目标记录不存在 +var ErrNotFound = errors.New("not found") diff --git a/server/repository/mongo/repository.go b/server/repository/mongo/repository.go index 26b7a8a..e208a74 100644 --- a/server/repository/mongo/repository.go +++ b/server/repository/mongo/repository.go @@ -7,6 +7,7 @@ import ( "time" "go.mongodb.org/mongo-driver/bson" + "go.mongodb.org/mongo-driver/bson/primitive" "go.mongodb.org/mongo-driver/mongo" "go.mongodb.org/mongo-driver/mongo/options" @@ -383,6 +384,40 @@ func (r *Repository) GetMonthlyStats() ([]model.MonthlyStat, error) { return stats, nil } +// UpdateCleanedBillByID 按 ID 更新清洗后的账单,并返回更新后的记录 +func (r *Repository) UpdateCleanedBillByID(id string, updates map[string]interface{}) (*model.CleanedBill, error) { + if r.cleanedCollection == nil { + return nil, fmt.Errorf("cleaned collection not initialized") + } + + oid, err := primitive.ObjectIDFromHex(id) + if err != nil { + return nil, fmt.Errorf("invalid id: %w", err) + } + + if len(updates) == 0 { + return nil, fmt.Errorf("no updates") + } + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + filter := bson.M{"_id": oid} + update := bson.M{"$set": updates} + opts := options.FindOneAndUpdate().SetReturnDocument(options.After) + + var updated model.CleanedBill + err = r.cleanedCollection.FindOneAndUpdate(ctx, filter, update, opts).Decode(&updated) + if err != nil { + if err == mongo.ErrNoDocuments { + return nil, repository.ErrNotFound + } + return nil, fmt.Errorf("update bill failed: %w", err) + } + + return &updated, nil +} + // GetClient 获取 MongoDB 客户端(用于兼容旧代码) func (r *Repository) GetClient() *mongo.Client { return r.client diff --git a/server/repository/repository.go b/server/repository/repository.go index 60e6d6c..191347e 100644 --- a/server/repository/repository.go +++ b/server/repository/repository.go @@ -43,6 +43,9 @@ type BillRepository interface { // GetBillsNeedReview 获取需要复核的账单 GetBillsNeedReview() ([]model.CleanedBill, error) + // UpdateCleanedBillByID 按 ID 更新清洗后的账单,并返回更新后的记录 + UpdateCleanedBillByID(id string, updates map[string]interface{}) (*model.CleanedBill, error) + // CountRawByField 按字段统计原始数据数量 CountRawByField(fieldName, value string) (int64, error) } diff --git a/server/router/router.go b/server/router/router.go index c1ac8e4..8b206e6 100644 --- a/server/router/router.go +++ b/server/router/router.go @@ -59,6 +59,9 @@ func setupAPIRoutes(r *gin.Engine) { // 账单查询 authed.GET("/bills", handler.ListBills) + // 编辑账单 + authed.PATCH("/bills/:id", handler.UpdateBill) + // 手动创建账单 authed.POST("/bills/manual", handler.CreateManualBills) diff --git a/web/src/lib/api.ts b/web/src/lib/api.ts index 40cb12e..766f8f8 100644 --- a/web/src/lib/api.ts +++ b/web/src/lib/api.ts @@ -249,6 +249,42 @@ export interface CleanedBill { review_level: string; } +// 更新账单 +export interface UpdateBillRequest { + time?: string; + category?: string; + merchant?: string; + description?: string; + income_expense?: string; + amount?: number; + pay_method?: string; + status?: string; + remark?: string; +} + +export interface UpdateBillResponse { + result: boolean; + message?: string; + data?: CleanedBill; +} + +export async function updateBill(id: string, patch: UpdateBillRequest): Promise { + const response = await apiFetch(`${API_BASE}/api/bills/${encodeURIComponent(id)}`, { + method: 'PATCH', + headers: { + 'Content-Type': 'application/json', + }, + body: JSON.stringify(patch), + }); + + if (!response.ok) { + // keep same behavior as other API calls + throw new Error(`HTTP ${response.status}`); + } + + return response.json(); +} + // 账单列表请求参数 export interface FetchBillsParams { page?: number; diff --git a/web/src/lib/components/analysis/BillRecordsTable.svelte b/web/src/lib/components/analysis/BillRecordsTable.svelte index 93b0f6f..9ee707d 100644 --- a/web/src/lib/components/analysis/BillRecordsTable.svelte +++ b/web/src/lib/components/analysis/BillRecordsTable.svelte @@ -19,7 +19,7 @@ import Tag from '@lucide/svelte/icons/tag'; import FileText from '@lucide/svelte/icons/file-text'; import CreditCard from '@lucide/svelte/icons/credit-card'; - import type { BillRecord } from '$lib/api'; + import { updateBill, type BillRecord } from '$lib/api'; interface Props { records: BillRecord[]; @@ -53,6 +53,7 @@ let selectedRecord = $state(null); let selectedIndex = $state(-1); let isEditing = $state(false); + let isSaving = $state(false); let editForm = $state({ amount: '', merchant: '', @@ -135,8 +136,11 @@ } // 保存编辑 - function saveEdit() { + async function saveEdit() { if (!selectedRecord) return; + + if (isSaving) return; + isSaving = true; const original = { ...selectedRecord }; const updated: BillRecord = { @@ -147,23 +151,52 @@ description: editForm.description, payment_method: editForm.payment_method }; - - // 更新本地数据 - const idx = records.findIndex(r => - r.time === selectedRecord!.time && - r.merchant === selectedRecord!.merchant && - r.amount === selectedRecord!.amount - ); - if (idx !== -1) { - records[idx] = updated; - records = [...records]; // 触发响应式更新 + + try { + // 如果有后端 ID,则持久化更新 + const billId = (selectedRecord as unknown as { id?: string }).id; + if (billId) { + const resp = await updateBill(billId, { + merchant: editForm.merchant, + category: editForm.category, + amount: Number(editForm.amount), + description: editForm.description, + pay_method: editForm.payment_method, + }); + + if (resp.result && resp.data) { + // 将后端返回的 CleanedBill 映射为 BillRecord + updated.amount = String(resp.data.amount); + updated.merchant = resp.data.merchant; + updated.category = resp.data.category; + updated.description = resp.data.description || ''; + updated.payment_method = resp.data.pay_method || ''; + // 让时间展示更稳定:使用后端格式 + updated.time = resp.data.time; + } + } + + // 更新本地数据(fallback:按引用/关键字段查找) + const idx = records.findIndex(r => r === selectedRecord); + const finalIdx = idx !== -1 + ? idx + : records.findIndex(r => + r.time === selectedRecord!.time && + r.merchant === selectedRecord!.merchant && + r.amount === selectedRecord!.amount + ); + + if (finalIdx !== -1) { + records[finalIdx] = updated; + records = [...records]; + } + + selectedRecord = updated; + isEditing = false; + onUpdate?.(updated, original); + } finally { + isSaving = false; } - - selectedRecord = updated; - isEditing = false; - - // 通知父组件 - onUpdate?.(updated, original); } // 处理分类选择 @@ -335,7 +368,7 @@ - + @@ -465,9 +498,9 @@ 取消 - {:else} - {:else}