Files
mcp-demo/golang/main.go
CHE LIANG ZHAO 6e8a93c8e9 Initial commit
2026-01-16 18:21:32 +08:00

294 lines
7.3 KiB
Go

package main
import (
"context"
"database/sql"
"encoding/json"
"fmt"
"log"
"os"
"path/filepath"
"strings"
_ "github.com/mattn/go-sqlite3"
"github.com/modelcontextprotocol/go-sdk/mcp"
)
// db is the shared SQLite handle used by every tool handler. It is set by
// initDB and closed by main.
var db *sql.DB

// ==================== Database initialization ====================

// initDB opens the demo SQLite database, creates the users and products
// tables if they do not exist, and seeds each table with sample rows when
// that table is empty.
//
// The database file is "demo.db" in the executable's directory when such a
// file exists there; otherwise "demo.db" relative to the current working
// directory is used.
//
// On success the package-level db holds the open handle. A non-nil error
// names the step that failed.
func initDB() error {
	execPath, err := os.Executable()
	if err != nil {
		// Fall back to the working directory when the executable path is unknown.
		execPath = "."
	}
	dbPath := filepath.Join(filepath.Dir(execPath), "demo.db")
	if _, err := os.Stat(dbPath); os.IsNotExist(err) {
		dbPath = "demo.db"
	}

	db, err = sql.Open("sqlite3", dbPath)
	if err != nil {
		return fmt.Errorf("opening database %q: %w", dbPath, err)
	}

	_, err = db.Exec(`
CREATE TABLE IF NOT EXISTS users (
id INTEGER PRIMARY KEY AUTOINCREMENT,
name TEXT NOT NULL,
email TEXT UNIQUE,
age INTEGER,
created_at DATETIME DEFAULT CURRENT_TIMESTAMP
);
CREATE TABLE IF NOT EXISTS products (
id INTEGER PRIMARY KEY AUTOINCREMENT,
name TEXT NOT NULL,
price REAL,
stock INTEGER DEFAULT 0
);
`)
	if err != nil {
		return fmt.Errorf("creating schema: %w", err)
	}

	// Seed each table on its own. products has no unique key, so tying its
	// seed to users being empty would duplicate the products whenever all
	// users had been deleted before a restart.
	seeds := []struct {
		table  string
		insert string
	}{
		{"users", `
INSERT INTO users (name, email, age) VALUES
('张三', 'zhangsan@example.com', 28),
('李四', 'lisi@example.com', 32),
('王五', 'wangwu@example.com', 25);
`},
		{"products", `
INSERT INTO products (name, price, stock) VALUES
('iPhone 15', 6999.00, 100),
('MacBook Pro', 14999.00, 50),
('AirPods Pro', 1899.00, 200);
`},
	}
	for _, s := range seeds {
		var count int
		// s.table comes from the fixed list above, never from user input.
		if err := db.QueryRow("SELECT COUNT(*) FROM " + s.table).Scan(&count); err != nil {
			return fmt.Errorf("counting rows in %s: %w", s.table, err)
		}
		if count > 0 {
			continue
		}
		if _, err := db.Exec(s.insert); err != nil {
			return fmt.Errorf("seeding %s: %w", s.table, err)
		}
	}
	return nil
}
// ==================== Tool input types ====================
//
// Field descriptions live in the `jsonschema` struct tag. The go-sdk derives
// each tool's input schema from these structs and reads descriptions from
// that tag; an `mcp:"..."` tag would be ignored and the descriptions lost.

// ListTablesInput is the (empty) argument set of the list_tables tool.
type ListTablesInput struct{}

// QueryInput holds the arguments of the query_db tool.
type QueryInput struct {
	SQL string `json:"sql" jsonschema:"SQL query to execute (SELECT only for safety)"`
}

// GetUserInput holds the arguments of the get_user tool.
type GetUserInput struct {
	ID int `json:"id" jsonschema:"User ID to retrieve"`
}

// AddUserInput holds the arguments of the add_user tool.
type AddUserInput struct {
	Name  string `json:"name" jsonschema:"User name"`
	Email string `json:"email" jsonschema:"User email address"`
	Age   int    `json:"age" jsonschema:"User age"`
}

// DeleteUserInput holds the arguments of the delete_user tool.
type DeleteUserInput struct {
	ID int `json:"id" jsonschema:"User ID to delete"`
}
// ==================== Helpers ====================

// textResult builds a successful tool result whose only content is a single
// TextContent holding text. The structured-output value and the error are
// both nil.
func textResult(text string) (*mcp.CallToolResult, any, error) {
	res := &mcp.CallToolResult{
		Content: []mcp.Content{&mcp.TextContent{Text: text}},
	}
	return res, nil, nil
}

// errorResult builds the same single-text result as textResult but marks it
// with IsError. The handler's Go error return remains nil; the failure is
// carried inside the result itself.
func errorResult(text string) (*mcp.CallToolResult, any, error) {
	res, out, err := textResult(text)
	res.IsError = true
	return res, out, err
}
// ==================== Tool implementations ====================

// ListTables implements the list_tables tool. It returns the names of all
// tables in sqlite_master except SQLite's reserved "sqlite_" tables, as an
// indented JSON array.
//
// Database failures are reported as an error result (IsError set), never as
// a Go error. The request context is passed to the query so a cancelled call
// stops early.
func ListTables(ctx context.Context, req *mcp.CallToolRequest, input ListTablesInput) (*mcp.CallToolResult, any, error) {
	// In LIKE, '_' is a single-character wildcard; escape it so only names
	// that literally start with "sqlite_" are hidden.
	rows, err := db.QueryContext(ctx,
		`SELECT name FROM sqlite_master WHERE type='table' AND name NOT LIKE 'sqlite\_%' ESCAPE '\'`)
	if err != nil {
		return errorResult(fmt.Sprintf("Failed to list tables: %v", err))
	}
	defer rows.Close()

	// Start non-nil so an empty database renders as [] rather than null.
	tables := []string{}
	for rows.Next() {
		var name string
		if err := rows.Scan(&name); err != nil {
			return errorResult(fmt.Sprintf("Failed to list tables: %v", err))
		}
		tables = append(tables, name)
	}
	// rows.Next also returns false on an iteration error; surface it instead
	// of returning a silently truncated list.
	if err := rows.Err(); err != nil {
		return errorResult(fmt.Sprintf("Failed to list tables: %v", err))
	}

	// Marshaling a []string cannot fail.
	result, _ := json.MarshalIndent(tables, "", " ")
	return textResult(fmt.Sprintf("Tables in database:\n%s", string(result)))
}
// QueryDB implements the query_db tool. It runs one read-only SELECT
// statement and returns the rows as an indented JSON array of objects keyed
// by column name. If two result columns share a name, only the last value
// is kept.
//
// The checks below are applied to the text after trimming surrounding
// whitespace and trailing semicolons:
//   - it must start with SELECT (case-insensitive);
//   - it must not contain any other ';'. The sqlite3 driver may execute
//     every ';'-separated statement given to Query, so without this check
//     "SELECT 1; DROP TABLE users" would pass the prefix test. The check is
//     deliberately conservative: a ';' inside a string literal or comment is
//     rejected too.
//
// Failures are reported as error results (IsError set), never as Go errors.
func QueryDB(ctx context.Context, req *mcp.CallToolRequest, input QueryInput) (*mcp.CallToolResult, any, error) {
	stmt := strings.TrimRight(strings.TrimSpace(input.SQL), "; \t\r\n")
	if !strings.HasPrefix(strings.ToUpper(stmt), "SELECT") {
		return errorResult("Only SELECT queries are allowed for safety")
	}
	if strings.Contains(stmt, ";") {
		return errorResult("Only a single SELECT statement is allowed")
	}

	rows, err := db.QueryContext(ctx, stmt)
	if err != nil {
		return errorResult(fmt.Sprintf("Query failed: %v", err))
	}
	defer rows.Close()

	columns, err := rows.Columns()
	if err != nil {
		return errorResult(fmt.Sprintf("Query failed: %v", err))
	}

	// Start non-nil so a query with no rows renders as [] rather than null.
	results := []map[string]any{}
	for rows.Next() {
		// Scan every column into an untyped slot; Scan needs pointers.
		values := make([]any, len(columns))
		ptrs := make([]any, len(columns))
		for i := range values {
			ptrs[i] = &values[i]
		}
		if err := rows.Scan(ptrs...); err != nil {
			return errorResult(fmt.Sprintf("Query failed: %v", err))
		}
		row := make(map[string]any, len(columns))
		for i, col := range columns {
			row[col] = values[i]
		}
		results = append(results, row)
	}
	if err := rows.Err(); err != nil {
		return errorResult(fmt.Sprintf("Query failed: %v", err))
	}

	// Marshaling can fail on values JSON cannot represent (e.g. an infinite
	// REAL), so report it instead of returning an empty body.
	jsonResult, err := json.MarshalIndent(results, "", " ")
	if err != nil {
		return errorResult(fmt.Sprintf("Failed to encode query results: %v", err))
	}
	return textResult(fmt.Sprintf("Query results (%d rows):\n%s", len(results), string(jsonResult)))
}
// GetUser implements the get_user tool. It looks up one row of users by
// primary key and returns it as an indented JSON object with the keys id,
// name, email, age and created_at.
//
// email, age and created_at are nullable in the schema. NULLs are rendered
// as JSON null instead of making the scan fail. An unknown ID and any query
// failure are reported as error results.
func GetUser(ctx context.Context, req *mcp.CallToolRequest, input GetUserInput) (*mcp.CallToolResult, any, error) {
	var (
		id        int
		name      string
		email     sql.NullString
		age       sql.NullInt64
		createdAt sql.NullString
	)
	err := db.QueryRowContext(ctx,
		"SELECT id, name, email, age, created_at FROM users WHERE id = ?", input.ID).
		Scan(&id, &name, &email, &age, &createdAt)
	// Row.Scan returns sql.ErrNoRows itself (unwrapped), so == is sufficient.
	if err == sql.ErrNoRows {
		return errorResult(fmt.Sprintf("User with ID %d not found", input.ID))
	}
	if err != nil {
		return errorResult(fmt.Sprintf("Query failed: %v", err))
	}

	user := map[string]any{
		"id":         id,
		"name":       name,
		"email":      nil,
		"age":        nil,
		"created_at": nil,
	}
	if email.Valid {
		user["email"] = email.String
	}
	if age.Valid {
		user["age"] = age.Int64
	}
	if createdAt.Valid {
		user["created_at"] = createdAt.String
	}

	// Every value is an int, int64, string or nil, so marshaling cannot fail.
	jsonResult, _ := json.MarshalIndent(user, "", " ")
	return textResult(string(jsonResult))
}
// AddUser implements the add_user tool. It inserts one row into users and
// reports the new row's ID. Insert failures (for example a duplicate email,
// which the schema declares UNIQUE) are returned as error results.
func AddUser(ctx context.Context, req *mcp.CallToolRequest, input AddUserInput) (*mcp.CallToolResult, any, error) {
	result, err := db.ExecContext(ctx,
		"INSERT INTO users (name, email, age) VALUES (?, ?, ?)",
		input.Name, input.Email, input.Age)
	if err != nil {
		return errorResult(fmt.Sprintf("Failed to add user: %v", err))
	}
	id, err := result.LastInsertId()
	if err != nil {
		// The row was inserted; only its ID could not be read back.
		return textResult(fmt.Sprintf("User added successfully, but its ID is unavailable: %v", err))
	}
	return textResult(fmt.Sprintf("User added successfully with ID: %d", id))
}
// DeleteUser implements the delete_user tool. It deletes the users row with
// the given ID. A delete that matches no row is reported as a "not found"
// error result, and so is any database failure.
func DeleteUser(ctx context.Context, req *mcp.CallToolRequest, input DeleteUserInput) (*mcp.CallToolResult, any, error) {
	result, err := db.ExecContext(ctx, "DELETE FROM users WHERE id = ?", input.ID)
	if err != nil {
		return errorResult(fmt.Sprintf("Failed to delete user: %v", err))
	}
	affected, err := result.RowsAffected()
	if err != nil {
		// Report the real failure rather than misreading it as "not found".
		return errorResult(fmt.Sprintf("Failed to delete user: %v", err))
	}
	if affected == 0 {
		return errorResult(fmt.Sprintf("User with ID %d not found", input.ID))
	}
	return textResult(fmt.Sprintf("User with ID %d deleted successfully", input.ID))
}
// ==================== Entry point ====================

// main initializes the database and registers the five SQLite tools on an
// MCP server. It then serves them over the stdio transport, blocking in
// server.Run. Diagnostics go through the log package, which writes to stderr
// by default, so stdout stays reserved for the transport.
func main() {
	if err := initDB(); err != nil {
		log.Fatalf("Failed to initialize database: %v", err)
	}
	defer db.Close()

	server := mcp.NewServer(
		&mcp.Implementation{
			Name:    "mcp-demo-go",
			Version: "1.0.0",
		},
		nil,
	)

	mcp.AddTool(server, &mcp.Tool{
		Name:        "list_tables",
		Description: "List all tables in the SQLite database",
	}, ListTables)
	mcp.AddTool(server, &mcp.Tool{
		Name:        "query_db",
		Description: "Execute a SELECT SQL query on the database",
	}, QueryDB)
	mcp.AddTool(server, &mcp.Tool{
		Name:        "get_user",
		Description: "Get a user by ID from the users table",
	}, GetUser)
	mcp.AddTool(server, &mcp.Tool{
		Name:        "add_user",
		Description: "Add a new user to the database",
	}, AddUser)
	mcp.AddTool(server, &mcp.Tool{
		Name:        "delete_user",
		Description: "Delete a user from the database by ID",
	}, DeleteUser)

	log.Println("MCP Demo Go Server (SQLite) is running...")
	if err := server.Run(context.Background(), &mcp.StdioTransport{}); err != nil {
		// log.Fatalf exits via os.Exit, which skips deferred calls, so close
		// the database explicitly before exiting.
		db.Close()
		log.Fatalf("Server error: %v", err)
	}
}