diff --git a/edu/src/vector-db.md b/edu/src/vector-db.md
index 23a9263..50b117a 100644
--- a/edu/src/vector-db.md
+++ b/edu/src/vector-db.md
@@ -577,7 +577,244 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
 ### 8. Exercise 2: K-Nearest Neighbor Search
-**Goal:** Use `vector_top_k` and `vector_distance_cos` to find the *k* vectors in the database most similar to a query vector, and display the results ranked by similarity score. 🚧 Full content tracked in [nbd:5674ce].
+**Goal:** Given a query vector, use `vector_top_k` to find the 3 most similar items, join with the `items` table to retrieve labels and exact cosine distances, and display the results ranked by distance.
+
+#### Step 1: Introduce `vector_top_k`
+
+`vector_top_k` is a **table-valued function** (TVF) that returns the row IDs of approximate nearest neighbours without performing a full table scan. It uses the approximate-nearest-neighbour (ANN) index created in §6 to navigate directly to the neighbourhood of the query vector. The syntax is:
+
+```sql
+SELECT i.id FROM vector_top_k('items_vec_idx', vector(?), ?) i
+```
+
+The three arguments are:
+
+1. **Index name** (string literal): the vector index to search, here `items_vec_idx` from §6. Note that this is the name of the index, not of the table.
+2. **Query vector**: passed through `vector()` as a JSON array string, just like when inserting data.
+3. **k**: the number of nearest neighbours to return.
+
+The function returns only row IDs, exposed as a single `id` column. It does not return labels, embeddings, or distances. To access other columns you must JOIN the result back to the original table on `rowid`. This design keeps the TVF focused on index traversal and lets you choose exactly which columns to retrieve.
+
+#### Step 2: Full KNN Query
+
+Combine the TVF with a JOIN and an exact distance computation to get labelled, ranked results:
+
+```sql
+SELECT items.id, items.label, vector_distance_cos(items.embedding, vector(?)) AS dist
+FROM vector_top_k('items_vec_idx', vector(?), ?) AS knn
+JOIN items ON items.rowid = knn.id
+ORDER BY dist ASC
+```
+
+Notice that the **query vector must be passed twice**: once as the second argument to `vector_top_k` (for index traversal to find candidate rows) and once as the second argument to `vector_distance_cos` (for exact distance computation on those candidates). Both are the same JSON array string bound to separate SQL parameters.
+
+Why two passes? `vector_top_k` uses the index to quickly identify *which* rows are likely nearest neighbours, but it does not return distance values. `vector_distance_cos` then computes the exact cosine distance for each candidate row, which you use for ranking and display.
+
+#### Step 3: Run Three Queries and Print Results
+
+Define a helper function that runs the KNN query for a given query vector and prints the results:
+
+```rust
+async fn knn_query(
+    conn: &libsql::Connection,
+    query: &[f32],
+    k: i32,
+) -> Result<(), Box<dyn std::error::Error>> {
+    let q = vec_to_json(query);
+    let mut rows = conn
+        .query(
+            "SELECT items.id, items.label, vector_distance_cos(items.embedding, vector(?)) AS dist
+             FROM vector_top_k('items_vec_idx', vector(?), ?) AS knn
+             JOIN items ON items.rowid = knn.id
+             ORDER BY dist ASC",
+            libsql::params![q.clone(), q.clone(), k],
+        )
+        .await?;
+
+    println!("Query: {q}");
+    let mut rank = 1;
+    while let Some(row) = rows.next().await? {
+        let label: String = row.get(1)?;
+        let dist: f64 = row.get(2)?;
+        println!("  {rank}. {label:<10} dist={dist:.4}");
+        rank += 1;
+    }
+    println!();
+    Ok(())
+}
+```
+
+Run three queries, each probing one of the three clusters from the §7 dataset:
+
+```rust
+// Animal cluster
+knn_query(&conn, &[0.85, 0.15, 0.25], 3).await?;
+
+// Vehicle cluster
+knn_query(&conn, &[0.15, 0.85, 0.15], 3).await?;
+
+// Language cluster
+knn_query(&conn, &[0.1, 0.05, 0.92], 3).await?;
+```
+
+Expected output (distances rounded to four decimals; the last digit may differ slightly with floating-point precision):
+
+```
+Query: [0.85,0.15,0.25]
+  1. cat        dist=0.0040
+  2. dog        dist=0.0045
+  3. truck      dist=0.5541
+
+Query: [0.15,0.85,0.15]
+  1. car        dist=0.0039
+  2. truck      dist=0.0045
+  3. dog        dist=0.5642
+
+Query: [0.1,0.05,0.92]
+  1. rust       dist=0.0000
+  2. python     dist=0.0024
+  3. dog        dist=0.5499
+```
+
+Each query correctly identifies the two items in its target cluster as the closest matches. The third result is always from a different cluster, with a distance more than a hundred times larger.
+
+#### Step 4: ANN vs. Exact Search
+
+For the 6-row dataset used in these exercises, `vector_top_k` effectively performs an **exact search**: the index graph has too few nodes to offer a meaningful shortcut, so the traversal ends up examining every vector. The results are identical to brute-force KNN.
+
+At scale (millions of rows) `vector_top_k` returns **approximate** results. The index navigates its graph greedily, which means some true nearest neighbours may be missed if they are poorly connected in the graph. This is the recall-vs-speed trade-off discussed in §5: the index answers queries in milliseconds instead of seconds, but recall@k is often around 0.95 rather than 1.0, depending on the index settings and the data.
+
+`vector_distance_cos`, by contrast, always gives the **exact** cosine distance for any specific pair of vectors. It is a pure computation with no approximation. The approximation lives only in the *selection* of which candidates to evaluate, and that is the job of the index.
+
+In practice this means: trust `vector_top_k` for fast retrieval, but understand that at scale a small fraction of true nearest neighbours may not appear in the result set. If perfect recall is required, you can widen the index's search breadth when the engine exposes such a setting (libSQL documents a `search_l` index option for this; HNSW-based engines call the equivalent knob `ef_search`), or fall back to brute-force search over a filtered subset.
+
+#### Reference Solution
+
+<details>
+<summary>Show full solution</summary>
+
+**`Cargo.toml`** (dependencies only):
+
+```toml
+[dependencies]
+libsql = "0.9"
+tokio = { version = "1", features = ["full"] }
+serde_json = "1"
+```
+
+**`src/main.rs`**:
+
+```rust
+use libsql::{Builder, Database};
+
+/// Convert a float slice to a JSON array string for libSQL's `vector()` function.
+fn vec_to_json(v: &[f32]) -> String {
+    format!(
+        "[{}]",
+        v.iter()
+            .map(|x| x.to_string())
+            .collect::<Vec<_>>()
+            .join(",")
+    )
+}
+
+/// Run a KNN query and print the top-k results with labels and distances.
+async fn knn_query(
+    conn: &libsql::Connection,
+    query: &[f32],
+    k: i32,
+) -> Result<(), Box<dyn std::error::Error>> {
+    let q = vec_to_json(query);
+    // `vector_top_k` takes the index name and yields row IDs in its `id` column.
+    let mut rows = conn
+        .query(
+            "SELECT items.id, items.label, vector_distance_cos(items.embedding, vector(?)) AS dist
+             FROM vector_top_k('items_vec_idx', vector(?), ?) AS knn
+             JOIN items ON items.rowid = knn.id
+             ORDER BY dist ASC",
+            libsql::params![q.clone(), q.clone(), k],
+        )
+        .await?;
+
+    println!("Query: {q}");
+    let mut rank = 1;
+    while let Some(row) = rows.next().await? {
+        let label: String = row.get(1)?;
+        let dist: f64 = row.get(2)?;
+        println!("  {rank}. {label:<10} dist={dist:.4}");
+        rank += 1;
+    }
+    println!();
+    Ok(())
+}
+
+#[tokio::main]
+async fn main() -> Result<(), Box<dyn std::error::Error>> {
+    // --- Open database ---
+    let db: Database = Builder::new_local("vectors.db").build().await?;
+    let conn = db.connect()?;
+
+    // Verify connection
+    let mut rows = conn.query("SELECT sqlite_version()", ()).await?;
+    if let Some(row) = rows.next().await? {
+        let version: String = row.get(0)?;
+        println!("SQLite version: {version}");
+    }
+
+    // --- Create table (from §6) ---
+    conn.execute(
+        "CREATE TABLE IF NOT EXISTS items (
+            id INTEGER PRIMARY KEY,
+            label TEXT NOT NULL,
+            embedding F32_BLOB(3) NOT NULL
+        )",
+        (),
+    )
+    .await?;
+
+    // --- Create vector index (from §6) ---
+    // libSQL wraps the indexed column in `libsql_vector_idx(...)`.
+    conn.execute(
+        "CREATE INDEX IF NOT EXISTS items_vec_idx
+         ON items (libsql_vector_idx(embedding))",
+        (),
+    )
+    .await?;
+
+    println!("Database ready.");
+
+    // --- Insert 6 labelled vectors (from §7) ---
+    let data: Vec<(i64, &str, Vec<f32>)> = vec![
+        (1, "cat", vec![0.9, 0.1, 0.2]),
+        (2, "dog", vec![0.8, 0.2, 0.3]),
+        (3, "car", vec![0.1, 0.9, 0.1]),
+        (4, "truck", vec![0.2, 0.8, 0.2]),
+        (5, "python", vec![0.15, 0.1, 0.95]),
+        (6, "rust", vec![0.1, 0.05, 0.9]),
+    ];
+
+    // `INSERT OR IGNORE` keeps re-runs idempotent: existing ids are skipped.
+    for (id, label, embedding) in &data {
+        conn.execute(
+            "INSERT OR IGNORE INTO items (id, label, embedding) VALUES (?, ?, vector(?))",
+            libsql::params![*id, *label, vec_to_json(embedding)],
+        )
+        .await?;
+    }
+    println!("Processed {} rows.", data.len());
+
+    // --- KNN queries ---
+    // Animal cluster
+    knn_query(&conn, &[0.85, 0.15, 0.25], 3).await?;
+
+    // Vehicle cluster
+    knn_query(&conn, &[0.15, 0.85, 0.15], 3).await?;
+
+    // Language cluster
+    knn_query(&conn, &[0.1, 0.05, 0.92], 3).await?;
+
+    Ok(())
+}
+```
+
+</details>
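The brute-force fallback and the recall@k figure mentioned in Step 4 can be made concrete without a database. The following is a minimal sketch, not part of the exercise solution: it recomputes cosine distances in plain Rust over the same six vectors from §7 and ranks them exhaustively, which gives a ground truth to compare `vector_top_k` results against. The names `dataset`, `cosine_distance`, `exact_top_k`, and `recall_at_k` are hypothetical helpers introduced here for illustration only.

```rust
/// The six labelled vectors from the exercise dataset.
fn dataset() -> [(&'static str, [f32; 3]); 6] {
    [
        ("cat", [0.9, 0.1, 0.2]),
        ("dog", [0.8, 0.2, 0.3]),
        ("car", [0.1, 0.9, 0.1]),
        ("truck", [0.2, 0.8, 0.2]),
        ("python", [0.15, 0.1, 0.95]),
        ("rust", [0.1, 0.05, 0.9]),
    ]
}

/// Cosine distance: 1 - (a . b) / (|a| * |b|), accumulated in f64.
fn cosine_distance(a: &[f32], b: &[f32]) -> f64 {
    let dot: f64 = a
        .iter()
        .zip(b)
        .map(|(x, y)| f64::from(*x) * f64::from(*y))
        .sum();
    let norm_a = a.iter().map(|x| f64::from(*x).powi(2)).sum::<f64>().sqrt();
    let norm_b = b.iter().map(|x| f64::from(*x).powi(2)).sum::<f64>().sqrt();
    1.0 - dot / (norm_a * norm_b)
}

/// Brute-force KNN: score every item, sort ascending by distance, keep k.
fn exact_top_k<'a>(
    items: &[(&'a str, [f32; 3])],
    query: &[f32],
    k: usize,
) -> Vec<(&'a str, f64)> {
    let mut scored: Vec<(&'a str, f64)> = items
        .iter()
        .map(|(label, v)| (*label, cosine_distance(v, query)))
        .collect();
    scored.sort_by(|a, b| a.1.total_cmp(&b.1));
    scored.truncate(k);
    scored
}

/// recall@k: fraction of the exact neighbours that the approximate result contains.
fn recall_at_k(exact: &[&str], approx: &[&str]) -> f64 {
    let hits = exact.iter().filter(|l| approx.contains(*l)).count();
    hits as f64 / exact.len() as f64
}

fn main() {
    let items = dataset();
    let top = exact_top_k(&items, &[0.85, 0.15, 0.25], 3);
    for (rank, (label, dist)) in top.iter().enumerate() {
        println!("  {}. {label:<10} dist={dist:.4}", rank + 1);
    }

    // Compare against labels returned by an ANN query (hard-coded here as an example).
    let exact_labels: Vec<&str> = top.iter().map(|(l, _)| *l).collect();
    let ann_labels = ["cat", "dog", "truck"];
    println!("recall@3 = {:.2}", recall_at_k(&exact_labels, &ann_labels));
}
```

For the animal-cluster query this ranks `cat`, `dog`, then `truck`. On a large table the same comparison, run over a sample of queries, gives an empirical recall figure for whatever index settings are in use.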
---