Skip to content

Commit 548825b

Browse files
authored
docs: add text2sql example for pytidb (#17)
* docs: add text2sql example for pytidb * docs: fix
1 parent 5890e6b commit 548825b

File tree

4 files changed

+199
-0
lines changed

4 files changed

+199
-0
lines changed

examples/text2sql/README.md

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
# Streamlit Examples
2+
3+
* Use `pytidb` to connect to TiDB
4+
* Use Streamlit as web ui
5+
6+
7+
## Prerequisites
8+
* Python 3.8+
9+
* OpenAI API key
10+
* TiDB server connection string, either local or TiDB Cloud
11+
12+
13+
## How to run
14+
15+
**Step0**: Clone the repo
16+
17+
```bash
18+
git clone https://github.com/pingcap/pytidb.git
19+
cd pytidb/examples/text2sql/;
20+
```
21+
22+
**Step1**: Install the required packages
23+
24+
```bash
25+
python -m venv .venv
26+
source .venv/bin/activate
27+
pip install -r reqs.txt
28+
```
29+
30+
**Step2**: Run the Streamlit app
31+
32+
```bash
33+
streamlit run main.py
34+
```
35+
36+
**Step3**: Open the browser and visit `http://localhost:8501`
37+
38+
* Input OpenAI API key in left sidebar
39+
* Input the TiDB Cloud connection string in left sidebar, the format is `mysql+pymysql://root@localhost:4000/test`

examples/text2sql/main.py

Lines changed: 153 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,153 @@
1+
#!/usr/bin/env python
2+
# -*- coding: utf-8 -*-
3+
import os
4+
5+
import dotenv
6+
from openai import OpenAI
7+
import streamlit as st
8+
from pytidb import TiDBClient
9+
from pydantic import BaseModel
10+
11+
dotenv.load_dotenv()
12+
13+
14+
class QuestionSQLResponse(BaseModel):
15+
question: str
16+
sql: str
17+
markdown: str
18+
19+
20+
st.set_page_config(page_title="Text2SQL", page_icon="📖", layout="wide")
21+
with st.sidebar:
22+
st.markdown("# Text2SQL")
23+
st.markdown(
24+
"## How to use\n"
25+
"1. Enter your [OpenAI API key](https://platform.openai.com/account/api-keys) below 🔑\n" # noqa: E501
26+
"2. Enter your [TiDB Cloud](https://tidbcloud.com) database connection URL below 🔗\n"
27+
"3. Ask a question in the right chat boxgh 🤖\n"
28+
)
29+
st.warning(
30+
"Please double check the generated SQL before running it on your database."
31+
)
32+
openai_api_key_input = st.text_input(
33+
"OpenAI API Key",
34+
type="password",
35+
placeholder="Paste your OpenAI API key here (sk-...)",
36+
help="You can get your API key from https://platform.openai.com/account/api-keys.", # noqa: E501
37+
value=os.environ.get("OPENAI_API_KEY", None)
38+
or st.session_state.get("OPENAI_API_KEY", ""),
39+
)
40+
database_url_input = st.text_input(
41+
"Database URL",
42+
type="password",
43+
placeholder="e.g. mysql+pymysql://root@localhost:4000/test",
44+
autocomplete="off",
45+
help="You can get your database URL from https://tidbcloud.com",
46+
value=os.environ.get("DATABASE_URL", None)
47+
or "mysql+pymysql://root@localhost:4000/test"
48+
or st.session_state.get("DATABASE_URL", ""),
49+
)
50+
st.session_state["OPENAI_API_KEY"] = openai_api_key_input
51+
st.session_state["DATABASE_URL"] = database_url_input
52+
53+
openai_api_key = st.session_state.get("OPENAI_API_KEY")
54+
database_url = st.session_state.get("DATABASE_URL")
55+
56+
if not openai_api_key or not database_url:
57+
st.error("Please enter your OpenAI API key and TiDB Cloud connection string.")
58+
st.stop()
59+
60+
db = TiDBClient.connect(database_url)
61+
oai = OpenAI(api_key=openai_api_key)
62+
63+
for item in ["generated", "past"]:
64+
if item not in st.session_state:
65+
st.session_state[item] = []
66+
67+
table_definitions = []
68+
current_database = db._db_engine.url.database
69+
for table_name in db.table_names():
70+
table_definitions.append(
71+
db.query(f"SHOW CREATE TABLE `{table_name}`").to_rows()[0]
72+
)
73+
74+
75+
def on_submit():
76+
user_input = st.session_state.user_input
77+
if user_input:
78+
response = (
79+
oai.beta.chat.completions.parse(
80+
model="gpt-4o-mini",
81+
messages=[
82+
{
83+
"role": "system",
84+
"content": f"""
85+
You are a very senior database administrator who can write SQL very well,
86+
please write MySQL SQL to answer user question,
87+
Use backticks to quote table names and column names,
88+
here are some table definitions in database,
89+
the database name is {current_database}\n\n"""
90+
+ "\n".join("|".join(t) for t in table_definitions),
91+
},
92+
{"role": "user", "content": f"Question: {user_input}\n"},
93+
],
94+
response_format=QuestionSQLResponse,
95+
)
96+
.choices[0]
97+
.message.parsed
98+
)
99+
st.session_state.past.append(user_input)
100+
101+
if 'insert' in response.sql.lower() or 'update' in response.sql.lower():
102+
st.error(
103+
"The generated SQL is not a SELECT statement, please check it carefully before running it."
104+
)
105+
st.stop()
106+
107+
# Execute the SQL query and set the result
108+
answer = None
109+
try:
110+
rows = db.query(response.sql).to_rows()
111+
sql_result = "\n".join(str(row) for row in rows)
112+
113+
answer = (
114+
oai.chat.completions.create(
115+
model="gpt-4o-mini",
116+
messages=[
117+
{
118+
"role": "system",
119+
"content": "You are a markdown formatter, format the user input to markdown, format the data row into markdown tables.",
120+
},
121+
{
122+
"role": "user",
123+
"content": f"""
124+
Question: {response.question}\n\n
125+
SQL: {response.sql}\n\n
126+
Markdown: {response.markdown}\n\n
127+
Result: {sql_result}""",
128+
},
129+
],
130+
)
131+
.choices[0]
132+
.message.content
133+
)
134+
st.session_state.generated.append(answer)
135+
except Exception as e:
136+
st.session_state.generated.append(f"Error: {e}")
137+
138+
139+
st.markdown("##### User Query")
140+
with st.container():
141+
st.chat_input(
142+
"Input your question, e.g. how many tables?",
143+
key="user_input",
144+
on_submit=on_submit,
145+
)
146+
147+
chat_placeholder = st.empty()
148+
with chat_placeholder.container():
149+
for i in range(len(st.session_state["generated"]) - 1, -1, -1):
150+
with st.chat_message("user"):
151+
st.write(st.session_state["past"][i])
152+
with st.chat_message("assistant"):
153+
st.write(st.session_state["generated"][i])

examples/text2sql/reqs.txt

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
openai
2+
pytidb
3+
pymysql
4+
streamlit
5+
httpx[socks]
6+
python-dotenv

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@ exclude = [
6262
"/.*",
6363
"/dist",
6464
"/docs",
65+
"/examples",
6566
"/tests",
6667
]
6768

0 commit comments

Comments
 (0)