class RisingWaveMCPAgent:
    """Agent for interacting with RisingWave via MCP and Anthropic."""

    def __init__(self, server_script_path: str):
        """Create the MCP and Anthropic clients and configure the server environment.

        Args:
            server_script_path: Passed to ``Client``; presumably the MCP server
                script to launch (confirm against ``Client``'s docs).
        """
        self.client = Client(server_script_path)
        # RisingWave connection settings handed to the MCP server; each falls
        # back to a local-development default when the variable is unset.
        env = {
            "RISINGWAVE_HOST": os.getenv("RISINGWAVE_HOST", "0.0.0.0"),
            "RISINGWAVE_USER": os.getenv("RISINGWAVE_USER", "root"),
            "RISINGWAVE_PASSWORD": os.getenv("RISINGWAVE_PASSWORD", "root"),
            "RISINGWAVE_PORT": os.getenv("RISINGWAVE_PORT", "4566"),
            "RISINGWAVE_DATABASE": os.getenv("RISINGWAVE_DATABASE", "dev"),
            "RISINGWAVE_SSLMODE": os.getenv("RISINGWAVE_SSLMODE", "disable"),
        }
        # NOTE(review): this assigns the transport's env wholesale; whether the
        # server process also inherits the parent environment (PATH, etc.)
        # depends on the MCP transport implementation -- confirm.
        self.client.transport.env = env
        self.anthropic = Anthropic()
        self.conversation = []
        self._tools_cache = None       # filled lazily by list_tools()
        self.table_cache = set()       # table names found by initialize_caches()
        self.mv_cache = set()          # materialized-view names found likewise
        self.table_descriptions = {}   # table name -> "describe_table" result
        self.mv_descriptions = {}      # MV name -> "describe_materialized_view" result
        self._init_done = False        # True once initialize_caches() has run

    async def initialize_caches(self):
        """Cache table and materialized-view names together with their descriptions.

        Runs at most once per instance: ``_init_done`` is set even when the
        listings fail, so failures are logged but not retried. Caching is
        best-effort and never raises for tool or decoding errors.
        """
        if self._init_done:
            return
        await self._cache_objects(
            list_tool="show_tables",
            describe_tool="describe_table",
            describe_arg="table_name",
            name_keys=("table_name", "name"),
            names=self.table_cache,
            descriptions=self.table_descriptions,
        )
        await self._cache_objects(
            list_tool="list_materialized_views",
            describe_tool="describe_materialized_view",
            describe_arg="mv_name",
            name_keys=("name", "mv_name"),
            names=self.mv_cache,
            descriptions=self.mv_descriptions,
        )
        self._init_done = True

    async def _cache_objects(self, list_tool, describe_tool, describe_arg,
                             name_keys, names, descriptions):
        """List objects via *list_tool*, then cache each name and its description.

        Best-effort: a failed, undecodable or non-list listing leaves the
        caches untouched; rows without a recognisable name are skipped
        individually (instead of aborting the whole listing); a failed
        describe call only omits that object's description.

        Args:
            list_tool: MCP tool returning the listing, either as a JSON string
                or as an already-decoded list.
            describe_tool: MCP tool called once per extracted name.
            describe_arg: Argument key under which *describe_tool* receives
                the name.
            name_keys: Keys tried in order to find the name in mapping rows.
            names: Set receiving every extracted name (mutated in place).
            descriptions: Dict receiving ``name -> describe result``.
        """
        import logging

        log = logging.getLogger(__name__)
        try:
            raw = await self.client.call_tool(list_tool, {})
            rows = json.loads(raw) if isinstance(raw, str) else raw
        except Exception:
            log.warning("Could not list objects via MCP tool %r", list_tool,
                        exc_info=True)
            return
        if not isinstance(rows, list):
            log.warning("Unexpected %r result type: %s", list_tool,
                        type(rows).__name__)
            return
        for row in rows:
            name = self._extract_name(row, name_keys)
            if name is None:
                continue
            names.add(name)
            try:
                descriptions[name] = await self.client.call_tool(
                    describe_tool, {describe_arg: name}
                )
            except Exception:
                log.debug("Could not describe %r via MCP tool %r", name,
                          describe_tool, exc_info=True)

    @staticmethod
    def _extract_name(row, keys):
        """Return the object name from one listing row, or ``None`` if absent.

        Accepts a bare string, a list/tuple whose first item is the name, or a
        mapping-like row (anything with ``.get``) searched for the first truthy
        value under *keys*. Only non-empty strings count as names.
        """
        candidate = None
        if isinstance(row, str):
            candidate = row
        elif isinstance(row, (list, tuple)):
            if row:
                candidate = row[0]
        elif hasattr(row, "get"):
            for key in keys:
                candidate = row.get(key)
                if candidate:
                    break
        if isinstance(candidate, str) and candidate:
            return candidate
        return None

    async def list_tools(self):
        """Return the client's tools as dicts, fetching them only on the first call.

        Each dict has the keys ``name``, ``description`` and ``input_schema``
        (taken from the tool's ``inputSchema``). The same list object is
        returned on later calls.
        """
        if self._tools_cache is None:
            tools = await self.client.list_tools()
            self._tools_cache = [{
                "name": tool.name,
                "description": tool.description,
                "input_schema": tool.inputSchema
            } for tool in tools]
        return self._tools_cache

    async def __aenter__(self):
        """Enter the underlying MCP client's async context and return this agent."""
        await self.client.__aenter__()
        return self

    async def __aexit__(self, exc_type, exc, tb):
        """Exit the underlying MCP client's async context.

        Returns ``None``, so this wrapper never suppresses exceptions.
        """
        await self.client.__aexit__(exc_type, exc, tb)