Skip to content

Commit

Permalink
WIP: Initial virtual table implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
jussisaurio committed Jan 19, 2025
1 parent 8767885 commit 5d1cc0f
Show file tree
Hide file tree
Showing 20 changed files with 951 additions and 61 deletions.
8 changes: 8 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ members = [
"simulator",
"sqlite3",
"test", "extensions/percentile",
"extensions/series",
]
exclude = ["perf/latency/limbo"]

Expand Down
139 changes: 137 additions & 2 deletions core/ext/mod.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,15 @@
use crate::{function::ExternalFunc, Database};
use limbo_ext::{ExtensionApi, InitAggFunction, ResultCode, ScalarFunction};
use crate::{
function::ExternalFunc,
schema::{Column, Type},
Database, VirtualTable,
};
use fallible_iterator::FallibleIterator;
use limbo_ext::{ExtensionApi, InitAggFunction, ResultCode, ScalarFunction, VTabModuleImpl};
pub use limbo_ext::{FinalizeFunction, StepFunction, Value as ExtValue, ValueType as ExtValueType};
use sqlite3_parser::{
ast::{Cmd, CreateTableBody, Stmt},
lexer::sql::Parser,
};
use std::{
ffi::{c_char, c_void, CStr},
rc::Rc,
Expand All @@ -13,6 +22,7 @@ unsafe extern "C" fn register_scalar_function(
func: ScalarFunction,
) -> ResultCode {
let c_str = unsafe { CStr::from_ptr(name) };
println!("Scalar??");
let name_str = match c_str.to_str() {
Ok(s) => s.to_string(),
Err(_) => return ResultCode::InvalidArgs,
Expand All @@ -32,6 +42,7 @@ unsafe extern "C" fn register_aggregate_function(
step_func: StepFunction,
finalize_func: FinalizeFunction,
) -> ResultCode {
println!("Aggregate??");
let c_str = unsafe { CStr::from_ptr(name) };
let name_str = match c_str.to_str() {
Ok(s) => s.to_string(),
Expand All @@ -44,6 +55,48 @@ unsafe extern "C" fn register_aggregate_function(
db.register_aggregate_function_impl(&name_str, args, (init_func, step_func, finalize_func))
}

unsafe extern "C" fn register_module(
ctx: *mut c_void,
name: *const c_char,
module: VTabModuleImpl,
) -> ResultCode {
let c_str = unsafe { CStr::from_ptr(name) };
let name_str = match c_str.to_str() {
Ok(s) => s.to_string(),
Err(_) => return ResultCode::Error,
};
if ctx.is_null() {
return ResultCode::Error;
}
let db = unsafe { &*(ctx as *const Database) };

db.register_module_impl(&name_str, module)
}

unsafe extern "C" fn declare_vtab(
ctx: *mut c_void,
name: *const c_char,
sql: *const c_char,
) -> ResultCode {
let c_str = unsafe { CStr::from_ptr(name) };
let name_str = match c_str.to_str() {
Ok(s) => s.to_string(),
Err(_) => return ResultCode::Error,
};

let c_str = unsafe { CStr::from_ptr(sql) };
let sql_str = match c_str.to_str() {
Ok(s) => s.to_string(),
Err(_) => return ResultCode::Error,
};

if ctx.is_null() {
return ResultCode::Error;
}
let db = unsafe { &*(ctx as *const Database) };
db.declare_vtab_impl(&name_str, &sql_str)
}

impl Database {
fn register_scalar_function_impl(&self, name: &str, func: ScalarFunction) -> ResultCode {
self.syms.borrow_mut().functions.insert(
Expand All @@ -66,11 +119,93 @@ impl Database {
ResultCode::OK
}

fn register_module_impl(&self, name: &str, module: VTabModuleImpl) -> ResultCode {
self.vtab_modules
.borrow_mut()
.insert(name.to_string(), Rc::new(module));
ResultCode::OK
}

fn declare_vtab_impl(&self, name: &str, sql: &str) -> ResultCode {
let mut parser = Parser::new(sql.as_bytes());
let cmd = parser.next().unwrap().unwrap();
let Cmd::Stmt(stmt) = cmd else {
return ResultCode::Error;
};
let Stmt::CreateTable { body, .. } = stmt else {
return ResultCode::Error;
};
let CreateTableBody::ColumnsAndConstraints { columns, .. } = body else {
return ResultCode::Error;
};

let columns = columns
.into_iter()
.filter_map(|(name, column_def)| {
// if column_def.col_type includes HIDDEN, omit it for now
if let Some(data_type) = column_def.col_type.as_ref() {
if data_type.name.as_str().contains("HIDDEN") {
return None;
}
}
let column = Column {
name: name.0.clone(),
// TODO extract to util, we use this elsewhere too.
ty: match column_def.col_type {
Some(data_type) => {
// https://www.sqlite.org/datatype3.html
let type_name = data_type.name.as_str().to_uppercase();
if type_name.contains("INT") {
Type::Integer
} else if type_name.contains("CHAR")
|| type_name.contains("CLOB")
|| type_name.contains("TEXT")
{
Type::Text
} else if type_name.contains("BLOB") || type_name.is_empty() {
Type::Blob
} else if type_name.contains("REAL")
|| type_name.contains("FLOA")
|| type_name.contains("DOUB")
{
Type::Real
} else {
Type::Numeric
}
}
None => Type::Null,
},
primary_key: column_def.constraints.iter().any(|c| {
matches!(
c.constraint,
sqlite3_parser::ast::ColumnConstraint::PrimaryKey { .. }
)
}),
is_rowid_alias: false,
};
Some(column)
})
.collect::<Vec<_>>();
let vtab_module = self.vtab_modules.borrow().get(name).unwrap().clone();
let vtab = VirtualTable {
name: name.to_string(),
implementation: vtab_module,
columns,
};
self.syms
.borrow_mut()
.vtabs
.insert(name.to_string(), Rc::new(vtab));
ResultCode::OK
}

pub fn build_limbo_ext(&self) -> ExtensionApi {
ExtensionApi {
ctx: self as *const _ as *mut c_void,
register_scalar_function,
register_aggregate_function,
register_module,
declare_vtab,
}
}
}
102 changes: 99 additions & 3 deletions core/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,11 @@ use fallible_iterator::FallibleIterator;
#[cfg(not(target_family = "wasm"))]
use libloading::{Library, Symbol};
#[cfg(not(target_family = "wasm"))]
use limbo_ext::{ExtensionApi, ExtensionEntryPoint};
use limbo_ext::{ExtensionApi, ExtensionEntryPoint, ResultCode};
use limbo_ext::{VTabModuleImpl, Value as ExtValue, ValueType};
use log::trace;
use schema::Schema;
use sqlite3_parser::ast;
use schema::{Column, Schema};
use sqlite3_parser::ast::{self};
use sqlite3_parser::{ast::Cmd, lexer::sql::Parser};
use std::cell::Cell;
use std::collections::HashMap;
Expand All @@ -40,8 +41,10 @@ use storage::pager::allocate_page;
use storage::sqlite3_ondisk::{DatabaseHeader, DATABASE_HEADER_SIZE};
pub use storage::wal::WalFile;
pub use storage::wal::WalFileShared;
use types::OwnedValue;
pub use types::Value;
use util::parse_schema_rows;
use vdbe::VTabOpaqueCursor;

pub use error::LimboError;
use translate::select::prepare_select_plan;
Expand Down Expand Up @@ -75,6 +78,7 @@ pub struct Database {
schema: Rc<RefCell<Schema>>,
header: Rc<RefCell<DatabaseHeader>>,
syms: Rc<RefCell<SymbolTable>>,
vtab_modules: Rc<RefCell<HashMap<String, Rc<VTabModuleImpl>>>>,
// Shared structures of a Database are the parts that are common to multiple threads that might
// create DB connections.
_shared_page_cache: Arc<RwLock<DumbLruPageCache>>,
Expand Down Expand Up @@ -137,6 +141,7 @@ impl Database {
_shared_page_cache: _shared_page_cache.clone(),
_shared_wal: shared_wal.clone(),
syms,
vtab_modules: Rc::new(RefCell::new(HashMap::new())),
};
let db = Arc::new(db);
let conn = Rc::new(Connection {
Expand Down Expand Up @@ -509,10 +514,100 @@ impl Rows {
}
}

#[derive(Clone, Debug)]
pub struct VirtualTable {
name: String,
pub implementation: Rc<VTabModuleImpl>,
columns: Vec<Column>,
}

impl VirtualTable {
pub fn open(&self) -> VTabOpaqueCursor {
let cursor = unsafe { (self.implementation.open)() };
VTabOpaqueCursor::new(cursor)
}

pub fn filter(
&self,
cursor: &VTabOpaqueCursor,
arg_count: usize,
args: Vec<OwnedValue>,
) -> Result<()> {
let mut filter_args = Vec::with_capacity(arg_count);
for i in 0..arg_count {
let ownedvalue_arg = args.get(i).unwrap();
let extvalue_arg: ExtValue = match ownedvalue_arg {
OwnedValue::Null => Ok(ExtValue::null()),
OwnedValue::Integer(i) => Ok(ExtValue::from_integer(*i)),
OwnedValue::Float(f) => Ok(ExtValue::from_float(*f)),
OwnedValue::Text(t) => Ok(ExtValue::from_text((*t.value).clone())),
OwnedValue::Blob(b) => Ok(ExtValue::from_blob((**b).clone())),
other => Err(LimboError::ExtensionError(format!(
"Unsupported value type: {:?}",
other
))),
}?;
filter_args.push(extvalue_arg);
}
let rc = unsafe {
(self.implementation.filter)(cursor.as_ptr(), arg_count as i32, filter_args.as_ptr())
};
match rc {
ResultCode::OK => Ok(()),
_ => Err(LimboError::ExtensionError("Filter failed".to_string())),
}
}

pub fn column(&self, cursor: &VTabOpaqueCursor, column: usize) -> Result<OwnedValue> {
let val = unsafe { (self.implementation.column)(cursor.as_ptr(), column as u32) };
match &val.value_type {
ValueType::Null => Ok(OwnedValue::Null),
ValueType::Integer => match val.to_integer() {
Some(i) => Ok(OwnedValue::Integer(i)),
None => Err(LimboError::ExtensionError(
"Failed to convert integer value".to_string(),
)),
},
ValueType::Float => match val.to_float() {
Some(f) => Ok(OwnedValue::Float(f)),
None => Err(LimboError::ExtensionError(
"Failed to convert float value".to_string(),
)),
},
ValueType::Text => match val.to_text() {
Some(t) => Ok(OwnedValue::build_text(Rc::new(t))),
None => Err(LimboError::ExtensionError(
"Failed to convert text value".to_string(),
)),
},
ValueType::Blob => match val.to_blob() {
Some(b) => Ok(OwnedValue::Blob(Rc::new(b))),
None => Err(LimboError::ExtensionError(
"Failed to convert blob value".to_string(),
)),
},
ValueType::Error => Err(LimboError::ExtensionError(format!(
"Error value in column {}",
column
))),
}
}

pub fn next(&self, cursor: &VTabOpaqueCursor) -> Result<bool> {
let rc = unsafe { (self.implementation.next)(cursor.as_ptr()) };
match rc {
ResultCode::OK => Ok(true),
ResultCode::EOF => Ok(false),
_ => Err(LimboError::ExtensionError("Next failed".to_string())),
}
}
}

pub(crate) struct SymbolTable {
pub functions: HashMap<String, Rc<function::ExternalFunc>>,
#[cfg(not(target_family = "wasm"))]
extensions: Vec<(Library, *const ExtensionApi)>,
pub vtabs: HashMap<String, Rc<VirtualTable>>,
}

impl std::fmt::Debug for SymbolTable {
Expand Down Expand Up @@ -554,6 +649,7 @@ impl SymbolTable {
pub fn new() -> Self {
Self {
functions: HashMap::new(),
vtabs: HashMap::new(),
// TODO: wasm libs will be very different
#[cfg(not(target_family = "wasm"))]
extensions: Vec::new(),
Expand Down
Loading

0 comments on commit 5d1cc0f

Please sign in to comment.