docs(edu): write §12 exercise 5 RAG pipeline for vector-db course [5ed295]

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
main
Elijah Voigt 3 months ago
parent 60c9fb67a8
commit e91bdd31ec

@ -1463,4 +1463,401 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
### 12. Exercise 5 — Retrieval-Augmented Generation
**Goal:** Build a retrieval-augmented generation (RAG) pipeline that:
1. Stores the 15-passage corpus from §10 in Turso
2. Accepts a natural-language question
3. Retrieves the top-3 most relevant passages using vector KNN
4. Injects the passages into a prompt as context
5. Sends the prompt to an OpenAI-compatible LLM API
6. Prints the grounded answer
**Setup:**
```toml
[dependencies]
libsql = "0.9"
fastembed = "4"
reqwest = { version = "0.12", features = ["json"] }
serde = { version = "1", features = ["derive"] }
serde_json = "1"
tokio = { version = "1", features = ["full"] }
```
You will need an API key stored in the `OPENAI_API_KEY` environment variable. This exercise works with any OpenAI-compatible provider: OpenAI itself, Groq, Together AI, or a local Ollama instance (base URL `http://localhost:11434/v1`, model `llama3.2`). If you are not using OpenAI, change the endpoint URL and model name that `call_llm` hardcodes in Step 3. Ollama ignores the key, but the code still reads the variable, so set it to any placeholder value.
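One way to avoid editing the source when switching providers is to read the base URL and model name from the environment. The sketch below is illustrative only: the `LlmConfig` type and the `LLM_BASE_URL` and `LLM_MODEL` variable names are assumptions made for this example, not part of the exercise code.

```rust
/// Connection settings for an OpenAI-compatible chat completions endpoint.
/// (Hypothetical helper type, not part of the exercise code.)
#[derive(Debug, PartialEq)]
struct LlmConfig {
    base_url: String,
    model: String,
}

impl LlmConfig {
    /// Build a config from optional overrides, falling back to OpenAI defaults.
    fn from_overrides(base_url: Option<String>, model: Option<String>) -> Self {
        let base_url = base_url.unwrap_or_else(|| "https://api.openai.com/v1".to_string());
        LlmConfig {
            // Drop trailing slashes so joining a path never yields "//".
            base_url: base_url.trim_end_matches('/').to_string(),
            model: model.unwrap_or_else(|| "gpt-4o-mini".to_string()),
        }
    }

    /// Read the overrides from the environment.
    /// `LLM_BASE_URL` and `LLM_MODEL` are assumed names chosen for this sketch.
    fn from_env() -> Self {
        Self::from_overrides(
            std::env::var("LLM_BASE_URL").ok(),
            std::env::var("LLM_MODEL").ok(),
        )
    }

    /// Full URL of the chat completions endpoint.
    fn chat_url(&self) -> String {
        format!("{}/chat/completions", self.base_url)
    }
}
```

With something like this in place, `call_llm` could take a `&LlmConfig` and use `config.chat_url()` and `config.model` instead of string literals.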
#### Step 1 — Retrieval function
Reuse the semantic search logic from §10. Write a function that embeds the query, runs a KNN search, and returns the top-k passage texts:
```rust
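use fastembed::TextEmbedding;

/// Embed `query`, run a KNN search, and return the top-k passage texts.
async fn retrieve(
    conn: &libsql::Connection,
    model: &TextEmbedding,
    query: &str,
    k: usize,
) -> Result<Vec<String>, Box<dyn std::error::Error>> {
    let q_emb = model.embed(vec![query.to_string()], None)?;
    let q_json = vec_to_json(&q_emb[0]);

    // vector_top_k is documented to return the ids of the nearest rows, so
    // the final ordering is computed explicitly with vector_distance_cos
    // (cosine is the default metric of libsql_vector_idx).
    let mut rows = conn
        .query(
            "SELECT d.passage
             FROM vector_top_k('docs_idx', vector(?1), ?2) AS v
             JOIN docs AS d ON d.rowid = v.id
             ORDER BY vector_distance_cos(d.embedding, vector(?1))",
            libsql::params![q_json.as_str(), k as i64],
        )
        .await?;

    let mut passages = Vec::new();
    while let Some(row) = rows.next().await? {
        let passage: String = row.get(0)?;
        passages.push(passage);
    }
    Ok(passages)
}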
```
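`retrieve` relies on the `vec_to_json` helper carried over from §10. For convenience, here is the version that appears in the reference solution at the end of this exercise; if your §10 helper differs slightly, either should work as long as it produces a JSON array that `vector(?)` can parse.

```rust
/// Serialize an embedding as a JSON array string, e.g. "[0.25,-1.5]".
fn vec_to_json(v: &[f32]) -> String {
    let parts: Vec<String> = v.iter().map(|x| format!("{x}")).collect();
    format!("[{}]", parts.join(","))
}
```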
#### Step 2 — Prompt construction
Build a prompt string that instructs the model to answer using only the retrieved context:
```rust
fn build_prompt(context_passages: &[String], question: &str) -> String {
    let mut prompt = String::from(
        "You are a helpful assistant. Answer the question using only the provided context.\n\
         If the context does not contain enough information, say so.\n\n\
         Context:\n",
    );
    for passage in context_passages {
        prompt.push_str(passage);
        prompt.push_str("\n\n");
    }
    prompt.push_str(&format!("Question: {question}\n\nAnswer:"));
    prompt
}
```
#### Step 3 — LLM API call
POST to the chat completions endpoint. Define request and response structs with serde, then send the prompt as a user message:
```rust
#[derive(serde::Serialize)]
struct ChatRequest {
    model: String,
    messages: Vec<Message>,
}

#[derive(serde::Serialize)]
struct Message {
    role: String,
    content: String,
}

#[derive(serde::Deserialize)]
struct ChatResponse {
    choices: Vec<Choice>,
}

#[derive(serde::Deserialize)]
struct Choice {
    message: ResponseMessage,
}

#[derive(serde::Deserialize)]
struct ResponseMessage {
    content: String,
}

async fn call_llm(
    client: &reqwest::Client,
    api_key: &str,
    prompt: &str,
) -> Result<String, Box<dyn std::error::Error>> {
    let request = ChatRequest {
        model: "gpt-4o-mini".to_string(),
        messages: vec![Message {
            role: "user".to_string(),
            content: prompt.to_string(),
        }],
    };

    let resp = client
        .post("https://api.openai.com/v1/chat/completions")
        .bearer_auth(api_key)
        .json(&request)
        .send()
        .await?
        .error_for_status()?
        .json::<ChatResponse>()
        .await?;

    // Return the first choice, or an error instead of a panic when the
    // provider sends back an empty `choices` array.
    let first = resp
        .choices
        .into_iter()
        .next()
        .ok_or("LLM response contained no choices")?;
    Ok(first.message.content)
}
```
#### Step 4 — Wire it together and run
Set up the database and corpus exactly as in §10, then run three example questions that exercise each topic cluster:
```rust
let questions = vec![
    "How does Rust ensure memory safety?",
    "What is a black hole?",
    "What is the Maillard reaction?",
];
let client = reqwest::Client::new();
let api_key = std::env::var("OPENAI_API_KEY")?;

for question in &questions {
    println!("=== Question: \"{question}\" ===\n");

    let passages = retrieve(&conn, &model, question, 3).await?;
    println!("Retrieved passages:");
    for (i, p) in passages.iter().enumerate() {
        println!("  {}: {p}", i + 1);
    }
    println!();

    let prompt = build_prompt(&passages, question);
    let answer = call_llm(&client, &api_key, &prompt).await?;
    println!("Answer: {answer}\n");
}
```
Each question should pull passages from the matching cluster: Rust passages for the first, astronomy for the second, and cooking for the third. Because the prompt restricts the model to the supplied context, the answer should be grounded in those passages rather than in the model's parametric knowledge alone.
#### Step 5 — Discussion: RAG patterns
**Chunk size and overlap.** The 15-passage corpus used here is already conveniently pre-chunked into single sentences, but real documents are rarely so tidy. In practice, long documents are split into overlapping chunks, typically 200 to 500 tokens each with a 50 to 100 token overlap between consecutive chunks. The overlap ensures that sentences near a chunk boundary are not orphaned from their surrounding context, which would hurt retrieval quality. Choosing the chunk size is a trade-off: smaller chunks yield more precise retrieval but lose broader context, while larger chunks retain context at the cost of noisier matches.
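As a rough illustration of overlapping chunking, the sketch below splits text on whitespace and treats each word as one token. That is a simplifying assumption: a real pipeline would count tokens with the embedding model's tokenizer. The `chunk_words` name and signature are made up for this example.

```rust
/// Split `text` into overlapping chunks of whitespace-delimited words.
/// Words stand in for tokens here, which is only an approximation.
fn chunk_words(text: &str, chunk_size: usize, overlap: usize) -> Vec<String> {
    assert!(chunk_size > 0, "chunk_size must be positive");
    assert!(overlap < chunk_size, "overlap must be smaller than chunk_size");

    let words: Vec<&str> = text.split_whitespace().collect();
    // Each new chunk starts `step` words after the previous one, so the
    // last `overlap` words of a chunk reappear at the start of the next.
    let step = chunk_size - overlap;
    let mut chunks = Vec::new();
    let mut start = 0;

    while start < words.len() {
        let end = (start + chunk_size).min(words.len());
        chunks.push(words[start..end].join(" "));
        if end == words.len() {
            break; // the final chunk reached the end of the text
        }
        start += step;
    }
    chunks
}
```

For example, `chunk_words("a b c d e f g h i j", 4, 1)` produces three chunks, `"a b c d"`, `"d e f g"`, and `"g h i j"`, where each chunk repeats the last word of the previous one.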
**Re-ranking.** The ANN index returns approximate nearest neighbors quickly, but the ranking is based on a single embedding similarity score. A cross-encoder re-ranker — a model that takes (query, passage) pairs as input and produces a relevance score — can re-order the top-k candidates for significantly better precision. The typical pattern is to retrieve a larger set (e.g., top-20) with ANN and then re-rank to the final top-3 or top-5 with the cross-encoder.
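The retrieve-then-re-rank pattern can be sketched independently of any particular model. In the example below, the `score` closure stands in for a cross-encoder, and `word_overlap_score` is a deliberately crude toy scorer included only so the sketch runs without a model. Both names are hypothetical and not part of the exercise code.

```rust
/// Re-order ANN candidates with a more expensive relevance scorer and keep
/// the best `final_k`. `score` stands in for a cross-encoder: it receives
/// (query, passage) and returns a relevance score, higher meaning better.
fn rerank<F>(query: &str, candidates: Vec<String>, final_k: usize, score: F) -> Vec<String>
where
    F: Fn(&str, &str) -> f32,
{
    // Score every candidate once, then sort by descending score.
    // The sort is stable, so ties keep their original ANN order.
    let mut scored: Vec<(f32, String)> = candidates
        .into_iter()
        .map(|p| (score(query, &p), p))
        .collect();
    scored.sort_by(|a, b| b.0.total_cmp(&a.0));
    scored.into_iter().take(final_k).map(|(_, p)| p).collect()
}

/// Toy scorer for illustration only: the fraction of query words that
/// occur in the passage. A real cross-encoder is a neural model.
fn word_overlap_score(query: &str, passage: &str) -> f32 {
    let passage_lower = passage.to_lowercase();
    let words: Vec<String> = query
        .split_whitespace()
        .map(|w| w.to_lowercase())
        .collect();
    if words.is_empty() {
        return 0.0;
    }
    let hits = words
        .iter()
        .filter(|w| passage_lower.contains(w.as_str()))
        .count();
    hits as f32 / words.len() as f32
}
```

In the pipeline from this exercise, one would call `retrieve` with a larger `k` (say 20) and pass the result through `rerank(question, passages, 3, scorer)` before building the prompt.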
**Hybrid search.** Semantic (ANN) search excels at matching meaning but can miss exact keywords, while keyword-based search (BM25) is great at exact term matching but blind to synonyms. Combining both, often called hybrid search, frequently outperforms either approach alone. A common fusion strategy is Reciprocal Rank Fusion (RRF), which gives each result a score equal to the sum of 1 / (k + rank) over the ranked lists it appears in, where k is a smoothing constant (60 is a common choice), and then sorts results by that score.
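A minimal sketch of Reciprocal Rank Fusion follows. It assumes each input list holds document ids ordered best first, and it uses the usual formulation in which a document's fused score is the sum of 1 / (k + rank) over the lists that contain it, with ranks starting at 1. The function name and the use of `i64` ids are choices made for this example.

```rust
use std::collections::HashMap;

/// Fuse several ranked lists of document ids with Reciprocal Rank Fusion.
/// Each document scores the sum of 1 / (k + rank) over the lists it
/// appears in, with ranks starting at 1. k = 60 is a commonly used value.
fn reciprocal_rank_fusion(rankings: &[Vec<i64>], k: f64) -> Vec<(i64, f64)> {
    let mut scores: HashMap<i64, f64> = HashMap::new();
    for ranking in rankings {
        for (idx, doc_id) in ranking.iter().enumerate() {
            let rank = (idx + 1) as f64;
            *scores.entry(*doc_id).or_insert(0.0) += 1.0 / (k + rank);
        }
    }
    let mut fused: Vec<(i64, f64)> = scores.into_iter().collect();
    // Highest fused score first; break ties by id for deterministic output.
    fused.sort_by(|a, b| b.1.total_cmp(&a.1).then(a.0.cmp(&b.0)));
    fused
}
```

Given a semantic ranking `[1, 2, 3]` and a keyword ranking `[3, 1, 4]` with `k = 60.0`, the fused order is 1, 3, 2, 4: documents 1 and 3 appear in both lists, and document 1 edges ahead because its ranks (1st and 2nd) beat those of document 3 (3rd and 1st).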
**Context window limits.** The number of passages you can inject depends on the model's context length and the average passage length. GPT-4o-mini supports 128k tokens, but stuffing the entire context window with retrieved passages introduces noise and increases latency and cost. A good heuristic is to inject only enough passages to cover the question — typically 3 to 5 short passages or 1 to 2 longer chunks — and to place the most relevant passages first.
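One simple way to apply this heuristic is to cap the injected context with a token budget. The sketch below assumes roughly four characters per token, which is only a rule of thumb for English text and not a real tokenizer; `estimate_tokens` and `fit_to_budget` are hypothetical helpers.

```rust
/// Rough token estimate: about 4 characters per token for English text.
/// This is a heuristic, not a tokenizer.
fn estimate_tokens(text: &str) -> usize {
    (text.chars().count() + 3) / 4
}

/// Keep passages in ranked order until the estimated token budget is used up.
/// Assumes `ranked` is sorted most relevant first.
fn fit_to_budget(ranked: &[String], max_tokens: usize) -> Vec<String> {
    let mut used = 0;
    let mut kept = Vec::new();
    for passage in ranked {
        let cost = estimate_tokens(passage);
        if used + cost > max_tokens {
            break; // stop at the first passage that does not fit
        }
        used += cost;
        kept.push(passage.clone());
    }
    kept
}
```

Because the list is cut from the tail, the most relevant passages are always the ones that survive, and they stay at the front of the prompt.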
#### Reference Solution
<details>
<summary>Show full solution</summary>
```rust
// src/main.rs — Retrieval-Augmented Generation (Exercise 5)
use fastembed::{EmbeddingModel, InitOptions, TextEmbedding};
use libsql::Builder;

/// Serialize an embedding as a JSON array string that `vector(?)` can parse.
fn vec_to_json(v: &[f32]) -> String {
    let parts: Vec<String> = v.iter().map(|x| format!("{x}")).collect();
    format!("[{}]", parts.join(","))
}

/// Retrieve the top-k passages most relevant to `query` using vector KNN.
async fn retrieve(
    conn: &libsql::Connection,
    model: &TextEmbedding,
    query: &str,
    k: usize,
) -> Result<Vec<String>, Box<dyn std::error::Error>> {
    let q_emb = model.embed(vec![query.to_string()], None)?;
    let q_json = vec_to_json(&q_emb[0]);

    // vector_top_k is documented to return the ids of the nearest rows, so
    // the final ordering is computed explicitly with vector_distance_cos
    // (cosine is the default metric of libsql_vector_idx).
    let mut rows = conn
        .query(
            "SELECT d.passage
             FROM vector_top_k('docs_idx', vector(?1), ?2) AS v
             JOIN docs AS d ON d.rowid = v.id
             ORDER BY vector_distance_cos(d.embedding, vector(?1))",
            libsql::params![q_json.as_str(), k as i64],
        )
        .await?;

    let mut passages = Vec::new();
    while let Some(row) = rows.next().await? {
        let passage: String = row.get(0)?;
        passages.push(passage);
    }
    Ok(passages)
}

/// Build a RAG prompt that instructs the model to answer from context only.
fn build_prompt(context_passages: &[String], question: &str) -> String {
    let mut prompt = String::from(
        "You are a helpful assistant. Answer the question using only the provided context.\n\
         If the context does not contain enough information, say so.\n\n\
         Context:\n",
    );
    for passage in context_passages {
        prompt.push_str(passage);
        prompt.push_str("\n\n");
    }
    prompt.push_str(&format!("Question: {question}\n\nAnswer:"));
    prompt
}

#[derive(serde::Serialize)]
struct ChatRequest {
    model: String,
    messages: Vec<Message>,
}

#[derive(serde::Serialize)]
struct Message {
    role: String,
    content: String,
}

#[derive(serde::Deserialize)]
struct ChatResponse {
    choices: Vec<Choice>,
}

#[derive(serde::Deserialize)]
struct Choice {
    message: ResponseMessage,
}

#[derive(serde::Deserialize)]
struct ResponseMessage {
    content: String,
}

/// Send the prompt to an OpenAI-compatible chat completions API.
async fn call_llm(
    client: &reqwest::Client,
    api_key: &str,
    prompt: &str,
) -> Result<String, Box<dyn std::error::Error>> {
    let request = ChatRequest {
        model: "gpt-4o-mini".to_string(),
        messages: vec![Message {
            role: "user".to_string(),
            content: prompt.to_string(),
        }],
    };

    let resp = client
        .post("https://api.openai.com/v1/chat/completions")
        .bearer_auth(api_key)
        .json(&request)
        .send()
        .await?
        .error_for_status()?
        .json::<ChatResponse>()
        .await?;

    // Return the first choice, or an error instead of a panic when the
    // provider sends back an empty `choices` array.
    let first = resp
        .choices
        .into_iter()
        .next()
        .ok_or("LLM response contained no choices")?;
    Ok(first.message.content)
}

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    // ── 1. Connect to Turso (local file) ──
    let db = Builder::new_local("rag_search.db").build().await?;
    let conn = db.connect()?;

    // ── 2. Create the docs table ──
    conn.execute(
        "CREATE TABLE IF NOT EXISTS docs (
            id INTEGER PRIMARY KEY,
            passage TEXT NOT NULL,
            embedding F32_BLOB(384) NOT NULL
        )",
        (),
    )
    .await?;

    // ── 3. Create the vector index ──
    conn.execute(
        "CREATE INDEX IF NOT EXISTS docs_idx ON docs (libsql_vector_idx(embedding))",
        (),
    )
    .await?;

    // ── 4. Define the corpus ──
    let passages: Vec<String> = vec![
        // Rust programming
        "Rust uses an ownership system to guarantee memory safety without a garbage collector.",
        "The borrow checker enforces that references do not outlive the data they point to.",
        "Cargo is Rust's build system and package manager, used to manage dependencies and run tests.",
        "Rust's trait system enables zero-cost abstractions and compile-time polymorphism.",
        "Async Rust uses futures and the tokio runtime to handle concurrent I/O efficiently.",
        // Astronomy
        "A black hole is a region of spacetime where gravity is so strong that nothing can escape.",
        "The Milky Way galaxy contains an estimated 100 to 400 billion stars.",
        "Neutron stars are the collapsed cores of massive stars, with densities exceeding atomic nuclei.",
        "The cosmic microwave background is the thermal radiation left over from the early universe.",
        "Exoplanets are planets outside our solar system, detected via transit photometry or radial velocity.",
        // Cooking
        "Maillard reaction gives browned foods their distinctive flavour through amino acid and sugar reactions.",
        "Sous vide cooking involves sealing food in vacuum bags and cooking at precise low temperatures.",
        "Emulsification combines two immiscible liquids, such as oil and water, using an emulsifier like lecithin.",
        "Fermentation converts sugars to acids or alcohol using microorganisms, used in bread, beer, and yogurt.",
        "Knife skills — julienne, brunoise, chiffonade — determine the surface area and cooking time of vegetables.",
    ]
    .into_iter()
    .map(String::from)
    .collect();

    // ── 5. Embed the corpus ──
    // fastembed 4 marks InitOptions as non-exhaustive, so a struct literal
    // does not compile outside the crate; use the builder methods instead.
    let model = TextEmbedding::try_new(
        InitOptions::new(EmbeddingModel::BGESmallENV15).with_show_download_progress(true),
    )?;
    let embeddings = model.embed(passages.clone(), None)?;

    // ── 6. Insert passages + embeddings ──
    for (i, (passage, emb)) in passages.iter().zip(embeddings.iter()).enumerate() {
        let json = vec_to_json(emb);
        conn.execute(
            "INSERT OR IGNORE INTO docs (id, passage, embedding) VALUES (?, ?, vector(?))",
            libsql::params![i as i64, passage.as_str(), json.as_str()],
        )
        .await?;
    }
    println!("Inserted {} passages.\n", passages.len());

    // ── 7. RAG pipeline ──
    let api_key = std::env::var("OPENAI_API_KEY")?;
    let client = reqwest::Client::new();
    let questions = vec![
        "How does Rust ensure memory safety?",
        "What is a black hole?",
        "What is the Maillard reaction?",
    ];

    for question in &questions {
        println!("=== Question: \"{question}\" ===\n");

        let context = retrieve(&conn, &model, question, 3).await?;
        println!("Retrieved passages:");
        for (i, p) in context.iter().enumerate() {
            println!("  {}: {p}", i + 1);
        }
        println!();

        let prompt = build_prompt(&context, question);
        let answer = call_llm(&client, &api_key, &prompt).await?;
        println!("Answer: {answer}\n");
    }

    Ok(())
}
```
</details>
