package main import ( "context" "database/sql" "encoding/json" "fmt" "log" "os" "path/filepath" "strings" _ "github.com/mattn/go-sqlite3" "github.com/modelcontextprotocol/go-sdk/mcp" ) var db *sql.DB // ==================== 数据库初始化 ==================== func initDB() error { // 获取可执行文件所在目录 execPath, err := os.Executable() if err != nil { execPath = "." } dbPath := filepath.Join(filepath.Dir(execPath), "demo.db") // 如果数据库不存在,使用当前目录 if _, err := os.Stat(dbPath); os.IsNotExist(err) { dbPath = "demo.db" } db, err = sql.Open("sqlite3", dbPath) if err != nil { return err } // 创建示例表 _, err = db.Exec(` CREATE TABLE IF NOT EXISTS users ( id INTEGER PRIMARY KEY AUTOINCREMENT, name TEXT NOT NULL, email TEXT UNIQUE, age INTEGER, created_at DATETIME DEFAULT CURRENT_TIMESTAMP ); CREATE TABLE IF NOT EXISTS products ( id INTEGER PRIMARY KEY AUTOINCREMENT, name TEXT NOT NULL, price REAL, stock INTEGER DEFAULT 0 ); `) if err != nil { return err } // 插入示例数据(如果表为空) var count int db.QueryRow("SELECT COUNT(*) FROM users").Scan(&count) if count == 0 { _, err = db.Exec(` INSERT INTO users (name, email, age) VALUES ('张三', 'zhangsan@example.com', 28), ('李四', 'lisi@example.com', 32), ('王五', 'wangwu@example.com', 25); INSERT INTO products (name, price, stock) VALUES ('iPhone 15', 6999.00, 100), ('MacBook Pro', 14999.00, 50), ('AirPods Pro', 1899.00, 200); `) if err != nil { return err } } return nil } // ==================== 工具输入结构 ==================== // ListTablesInput - 列出所有表 type ListTablesInput struct{} // QueryInput - 执行 SQL 查询 type QueryInput struct { SQL string `json:"sql" mcp:"SQL query to execute (SELECT only for safety)"` } // GetUserInput - 获取用户 type GetUserInput struct { ID int `json:"id" mcp:"User ID to retrieve"` } // AddUserInput - 添加用户 type AddUserInput struct { Name string `json:"name" mcp:"User name"` Email string `json:"email" mcp:"User email address"` Age int `json:"age" mcp:"User age"` } // DeleteUserInput - 删除用户 type DeleteUserInput struct { ID int `json:"id" mcp:"User ID to delete"` } // ==================== 辅助函数 ==================== func textResult(text string) (*mcp.CallToolResult, any, error) { return &mcp.CallToolResult{ Content: []mcp.Content{ &mcp.TextContent{Text: text}, }, }, nil, nil } func errorResult(text string) (*mcp.CallToolResult, any, error) { return &mcp.CallToolResult{ Content: []mcp.Content{ &mcp.TextContent{Text: text}, }, IsError: true, }, nil, nil } // ==================== 工具实现 ==================== // ListTables - 列出数据库中的所有表 func ListTables(ctx context.Context, req *mcp.CallToolRequest, input ListTablesInput) (*mcp.CallToolResult, any, error) { rows, err := db.Query("SELECT name FROM sqlite_master WHERE type='table' AND name NOT LIKE 'sqlite_%'") if err != nil { return errorResult(fmt.Sprintf("Failed to list tables: %v", err)) } defer rows.Close() var tables []string for rows.Next() { var name string rows.Scan(&name) tables = append(tables, name) } result, _ := json.MarshalIndent(tables, "", " ") return textResult(fmt.Sprintf("Tables in database:\n%s", string(result))) } // QueryDB - 执行 SQL 查询 func QueryDB(ctx context.Context, req *mcp.CallToolRequest, input QueryInput) (*mcp.CallToolResult, any, error) { // 安全检查:只允许 SELECT 查询 sqlUpper := strings.ToUpper(strings.TrimSpace(input.SQL)) if !strings.HasPrefix(sqlUpper, "SELECT") { return errorResult("Only SELECT queries are allowed for safety") } rows, err := db.Query(input.SQL) if err != nil { return errorResult(fmt.Sprintf("Query failed: %v", err)) } defer rows.Close() // 获取列名 columns, _ := rows.Columns() // 读取所有行 var results []map[string]interface{} for rows.Next() { values := make([]interface{}, len(columns)) valuePtrs := make([]interface{}, len(columns)) for i := range values { valuePtrs[i] = &values[i] } rows.Scan(valuePtrs...) row := make(map[string]interface{}) for i, col := range columns { row[col] = values[i] } results = append(results, row) } jsonResult, _ := json.MarshalIndent(results, "", " ") return textResult(fmt.Sprintf("Query results (%d rows):\n%s", len(results), string(jsonResult))) } // GetUser - 获取单个用户 func GetUser(ctx context.Context, req *mcp.CallToolRequest, input GetUserInput) (*mcp.CallToolResult, any, error) { var id int var name, email string var age int var createdAt string err := db.QueryRow("SELECT id, name, email, age, created_at FROM users WHERE id = ?", input.ID). Scan(&id, &name, &email, &age, &createdAt) if err == sql.ErrNoRows { return errorResult(fmt.Sprintf("User with ID %d not found", input.ID)) } if err != nil { return errorResult(fmt.Sprintf("Query failed: %v", err)) } user := map[string]interface{}{ "id": id, "name": name, "email": email, "age": age, "created_at": createdAt, } jsonResult, _ := json.MarshalIndent(user, "", " ") return textResult(string(jsonResult)) } // AddUser - 添加用户 func AddUser(ctx context.Context, req *mcp.CallToolRequest, input AddUserInput) (*mcp.CallToolResult, any, error) { result, err := db.Exec("INSERT INTO users (name, email, age) VALUES (?, ?, ?)", input.Name, input.Email, input.Age) if err != nil { return errorResult(fmt.Sprintf("Failed to add user: %v", err)) } id, _ := result.LastInsertId() return textResult(fmt.Sprintf("User added successfully with ID: %d", id)) } // DeleteUser - 删除用户 func DeleteUser(ctx context.Context, req *mcp.CallToolRequest, input DeleteUserInput) (*mcp.CallToolResult, any, error) { result, err := db.Exec("DELETE FROM users WHERE id = ?", input.ID) if err != nil { return errorResult(fmt.Sprintf("Failed to delete user: %v", err)) } affected, _ := result.RowsAffected() if affected == 0 { return errorResult(fmt.Sprintf("User with ID %d not found", input.ID)) } return textResult(fmt.Sprintf("User with ID %d deleted successfully", input.ID)) } // ==================== 主函数 ==================== func main() { // 初始化数据库 if err := initDB(); err != nil { log.Fatalf("Failed to initialize database: %v", err) } defer db.Close() // 创建 MCP Server server := mcp.NewServer( &mcp.Implementation{ Name: "mcp-demo-go", Version: "1.0.0", }, nil, ) // 注册工具 mcp.AddTool(server, &mcp.Tool{ Name: "list_tables", Description: "List all tables in the SQLite database", }, ListTables) mcp.AddTool(server, &mcp.Tool{ Name: "query_db", Description: "Execute a SELECT SQL query on the database", }, QueryDB) mcp.AddTool(server, &mcp.Tool{ Name: "get_user", Description: "Get a user by ID from the users table", }, GetUser) mcp.AddTool(server, &mcp.Tool{ Name: "add_user", Description: "Add a new user to the database", }, AddUser) mcp.AddTool(server, &mcp.Tool{ Name: "delete_user", Description: "Delete a user from the database by ID", }, DeleteUser) // 运行服务器 log.Println("MCP Demo Go Server (SQLite) is running...") if err := server.Run(context.Background(), &mcp.StdioTransport{}); err != nil { log.Fatalf("Server error: %v", err) } }