Skip to content

Commit

Permalink
Add default value
Browse files Browse the repository at this point in the history
  • Loading branch information
jigordev committed Jul 28, 2024
1 parent 07b5c42 commit 2300aca
Showing 1 changed file with 14 additions and 7 deletions.
21 changes: 14 additions & 7 deletions src/checkargs.lua
Original file line number Diff line number Diff line change
Expand Up @@ -19,46 +19,53 @@ local function check(use_error, condition, message)
end
end

function checkargs.check_arg(func, name, expected, value, optional, use_error)
function checkargs.check_arg(func, name, expected, value, optional, default, use_error)
check(use_error, contains(expected, type(value)) or (optional and value == nil),
string.format("Error in %s: Argument '%s' must be a %s, got: %s", func, name, table.concat(expected, ", "),
type(value)))
return value or default
end

function checkargs.check_list(func, name, expected, list, optional, use_error)
local args = {}
for _, arg in ipairs(list) do
checkargs.check_arg(func, name, expected, arg, optional, use_error)
table.insert(args, checkargs.check_arg(func, name, expected, arg, optional, use_error))
end
return args
end

function checkargs.check_range(func, name, value, min, max, use_error)
function checkargs.check_range(func, name, value, min, max, default, use_error)
check(use_error, type(value) == "number",
string.format("Error in %s: Argument '%s' must be a number, got: %s", func, name, type(value)))
check(use_error, value >= min and value <= max,
string.format("Error in %s: Argument '%s' must be between %d and %d, got: %d", func, name, min, max, value))
return value or default
end

function checkargs.check_fields(func, name, tbl, fields, use_error)
function checkargs.check_fields(func, name, tbl, fields, default, use_error)
check(use_error, type(tbl) == "table",
string.format("Error in %s: Argument '%s' must be a table, got: %s", func, name, type(tbl)))
for _, field in ipairs(fields) do
check(use_error, tbl[field] ~= nil,
string.format("Error in %s: Table '%s' must contain field '%s'", func, name, field))
end
return tbl or default
end

function checkargs.check_composite(func, name, value, expected_fields, use_error)
function checkargs.check_composite(func, name, value, expected_fields, default, use_error)
check(use_error, type(value) == "table",
string.format("Error in %s: Argument '%s' must be a table, got: %s", func, name, type(value)))
for field, field_type in pairs(expected_fields) do
check(use_error, type(value[field]) == field_type,
string.format("Error in %s: Field '%s' in argument '%s' must be a %s, got: %s", func, field, name, field_type,
type(value[field])))
end
return value or default
end

function checkargs.check_not_nil(func, name, value, use_error)
function checkargs.check_not_nil(func, name, value, default, use_error)
check(use_error, value ~= nil, string.format("Error in %s: Argument '%s' must not be nil", func, name))
return value or default
end

return checkargs
return checkargs

0 comments on commit 2300aca

Please sign in to comment.