From 288d91815a60193c2c0d5a32ccfc62977af44cdc Mon Sep 17 00:00:00 2001 From: Guilherme Andrade Date: Sun, 15 Oct 2023 20:10:21 +0100 Subject: [PATCH] Add test coverage of various edge cases, mostly errors This motivated me to raise argument errors, rather than return error tuples, when any API input is very clearly of the wrong type. --- lib/sqids.ex | 92 ++++++++++---- lib/sqids/agent.ex | 85 +++++++++---- lib/sqids/alphabet.ex | 19 ++- lib/sqids/blocklist.ex | 12 +- mix.exs | 2 +- test/sqids_test.exs | 265 +++++++++++++++++++++++++++++++++++++++-- 6 files changed, 413 insertions(+), 62 deletions(-) diff --git a/lib/sqids.ex b/lib/sqids.ex index 356c410..8b825bb 100644 --- a/lib/sqids.ex +++ b/lib/sqids.ex @@ -56,6 +56,15 @@ defmodule Sqids do blocklist: blocklist }} else + {:error, {tag, _} = reason} + when tag in [ + :alphabet_is_not_an_utf8_string, + :min_length_is_not_an_integer_in_range, + :blocklist_is_not_enumerable, + :some_words_in_blocklist_are_not_utf8_strings + ] -> + raise %ArgumentError{message: error_reason_to_string(reason)} + {:error, _} = error -> error end @@ -63,8 +72,13 @@ defmodule Sqids do @spec encode!(t(), enumerable(non_neg_integer)) :: String.t() def encode!(sqids, numbers) do - {:ok, string} = encode(sqids, numbers) - string + case encode(sqids, numbers) do + {:ok, string} -> + string + + {:error, {:all_id_generation_attempts_were_censored, _nr_of_attempts} = reason} -> + raise error_reason_to_string(reason) + end end @spec encode(t(), enumerable(non_neg_integer)) :: {:ok, String.t()} | {:error, term} @@ -73,18 +87,26 @@ defmodule Sqids do {:ok, numbers_list} -> encode_numbers(sqids, numbers_list) - {:error, _} = error -> - error + {:error, reason} -> + raise %ArgumentError{message: error_reason_to_string(reason)} end end + def encode(sqids, _numbers), do: :erlang.error({:badarg, sqids}) + @spec decode!(t(), String.t()) :: [non_neg_integer] def decode!(sqids, id) do - {:ok, numbers} = decode(sqids, id) - numbers + case decode(sqids, id) do + {:ok, numbers} -> + numbers + + # {:error, reason} -> + # raise error_reason_to_string(reason) + end end - @spec decode(t(), String.t()) :: {:ok, [non_neg_integer]} | {:error, term} + # | {:error, term} + @spec decode(t(), String.t()) :: {:ok, [non_neg_integer]} def decode(%Sqids{} = sqids, id) do case validate_id(sqids, id) do :ok -> @@ -98,11 +120,13 @@ defmodule Sqids do # Follow the spec's behaviour and return an empty list {:ok, []} - {:error, _} = error -> - error + {:error, {tag, _} = reason} when tag in [:id_is_not_utf8, :id_is_not_a_string] -> + raise %ArgumentError{message: error_reason_to_string(reason)} end end + def decode(sqids, _id), do: :erlang.error({:badarg, sqids}) + ## Internal Functions @doc false @@ -110,7 +134,7 @@ defmodule Sqids do defp validate_min_length(min_length) do if not is_integer(min_length) or min_length not in @min_length_range do - {:error, {:min_length_not_an_integer_in_range, min_length, range: @min_length_range}} + {:error, {:min_length_is_not_an_integer_in_range, value: min_length, range: @min_length_range}} else :ok end @@ -128,16 +152,17 @@ defmodule Sqids do end defp validate_numbers(numbers) do - numbers - |> Enum.find(&(not is_valid_number(&1))) - |> case do - nil -> - numbers_list = Enum.to_list(numbers) - {:ok, numbers_list} - - invalid_number -> - {:error, {:number_must_be_a_non_negative_integer, invalid_number}} - end + Enum.find(numbers, &(not is_valid_number(&1))) + catch + :error, %Protocol.UndefinedError{value: ^numbers} -> + {:error, {:numbers_not_enumerable, numbers}} + else + nil -> + numbers_list = Enum.to_list(numbers) + {:ok, numbers_list} + + invalid_number -> + {:error, {:number_is_not_a_non_negative_integer, invalid_number}} end defp is_valid_number(number), do: is_integer(number) and number >= 0 @@ -155,7 +180,7 @@ defmodule Sqids do defp attempt_to_encode_numbers(sqids, list, attempt_index) do if attempt_index > Alphabet.size(sqids.alphabet) do # We've reached max attempts - {:error, {:reached_max_attempts_to_regenerate_the_id, attempt_index - 1}} + {:error, {:all_id_generation_attempts_were_censored, _nr_of_attempts = attempt_index - 1}} else do_attempt_to_encode_numbers(sqids, list, attempt_index) end @@ -295,15 +320,20 @@ defmodule Sqids do defp validate_id(_sqids, ""), do: :empty_id defp validate_id(sqids, id) when is_binary(id) do - if are_all_chars_in_id_known(id, sqids.alphabet) do - :ok - else - :unknown_chars_in_id + case String.valid?(id) and {:all_chars_known, are_all_chars_in_id_known(id, sqids.alphabet)} do + {:all_chars_known, true} -> + :ok + + {:all_chars_known, false} -> + :unknown_chars_in_id + + false -> + {:error, {:id_is_not_utf8, id}} end end defp validate_id(_sqids, not_a_string) do - {:error, {:id_not_a_string, not_a_string}} + {:error, {:id_is_not_a_string, not_a_string}} end defp are_all_chars_in_id_known(id, alphabet) do @@ -380,6 +410,16 @@ defmodule Sqids do {:ok, numbers} end + defp error_reason_to_string({tag, details}) when is_atom(tag) do + "#{prettify_error_tag(tag)}: #{inspect(details)}" + end + + defp prettify_error_tag(tag) do + [first_word | next_words] = tag |> Atom.to_string() |> String.split("_") + first_word = String.capitalize(first_word) + Enum.join([first_word | next_words], " ") + end + ## Code generation defmacro __using__([]) do diff --git a/lib/sqids/agent.ex b/lib/sqids/agent.ex index 1974dd4..3de94f5 100644 --- a/lib/sqids/agent.ex +++ b/lib/sqids/agent.ex @@ -1,7 +1,7 @@ defmodule Sqids.Agent do @moduledoc """ - Storage for `Sqids` shared state. - Like stdlib's [Agent](https://hexdocs.pm/elixir/1.15/Agent.html) but using + Storage for `Sqids` shared state. + Like stdlib's [Agent](https://hexdocs.pm/elixir/1.15/Agent.html) but using OTP's [`persistent_term`](https://www.erlang.org/doc/man/persistent_term). """ @@ -32,8 +32,16 @@ defmodule Sqids.Agent do shared_state_init: shared_state_init ] - server_name = server_name(sqids_module) - GenServer.start_link(__MODULE__, init_args, name: server_name) + case :proc_lib.start_link(__MODULE__, :proc_lib_init, [init_args]) do + {:ok, _} = success -> + success + + {:error, _} = error -> + error + + {:intentional_raise, reason, stacktrace} -> + :erlang.raise(:error, reason, stacktrace) + end end @doc false @@ -55,27 +63,29 @@ defmodule Sqids.Agent do ## GenServer callbacks @doc false - @impl true - @spec init(init_args) :: {:ok, state} | no_return() - def init(init_args) do + @spec proc_lib_init(init_args) :: no_return() + def proc_lib_init(init_args) do sqids_module = Keyword.fetch!(init_args, :sqids_module) - {shared_state_init_fun, shared_state_args} = Keyword.fetch!(init_args, :shared_state_init) - - case apply(shared_state_init_fun, shared_state_args) do - {:ok, shared_state} -> - # Ensure `:terminate/2` gets called unless we're killed - _ = Process.flag(:trap_exit, true) - - shared_state_key = shared_state_key(sqids_module) - :persistent_term.put(shared_state_key, shared_state) - state = state(shared_state_key: shared_state_key) - {:ok, state} + server_name = server_name(sqids_module) - {:error, _} = error -> - init_fail(error, sqids_module) + try do + Process.register(self(), server_name) + catch + :error, %ArgumentError{} when is_atom(server_name) -> + init_fail({:error, {:already_started, Process.whereis(server_name)}}, server_name) + else + true -> + proc_lib_init_registered(init_args, sqids_module, server_name) end end + @doc false + @impl true + @spec init(term) :: no_return() + def init(_init_args) do + raise "Initialization is done through :proc_lib_init/1" + end + @doc false @impl true @spec terminate(term, state) :: term @@ -95,7 +105,39 @@ defmodule Sqids.Agent do ## Internal - defp init_fail(error, sqids_module) do + defp proc_lib_init_registered(init_args, sqids_module, server_name) do + {shared_state_init_fun, shared_state_args} = Keyword.fetch!(init_args, :shared_state_init) + + try do + apply(shared_state_init_fun, shared_state_args) + catch + :error, %ArgumentError{} = reason -> + stacktrace = __STACKTRACE__ + init_fail({:intentional_raise, reason, stacktrace}, server_name) + else + {:ok, shared_state} -> + # Ensure `:terminate/2` gets called unless we're killed + _ = Process.flag(:trap_exit, true) + + shared_state_key = shared_state_key(sqids_module) + :persistent_term.put(shared_state_key, shared_state) + state = state(shared_state_key: shared_state_key) + :proc_lib.init_ack({:ok, self()}) + + :gen_server.enter_loop( + __MODULE__, + _enter_loop_opts = [], + state, + {:local, server_name}, + :hibernate + ) + + {:error, _} = error -> + init_fail(error, server_name) + end + end + + defp init_fail(error, server_name) do # Use proc_lib:init_fail/2 instead of {:stop, reason} to avoid # polluting the logs: our supervisor will fail to start us and this # will already produce log messages with the relevant info. @@ -106,7 +148,6 @@ defmodule Sqids.Agent do catch :error, :undef -> # Fallback for OTP 25 or older - server_name = server_name(sqids_module) Process.unregister(server_name) :proc_lib.init_ack(error) :erlang.exit(:normal) diff --git a/lib/sqids/alphabet.ex b/lib/sqids/alphabet.ex index f3308e0..ac6517b 100644 --- a/lib/sqids/alphabet.ex +++ b/lib/sqids/alphabet.ex @@ -9,11 +9,18 @@ defmodule Sqids.Alphabet do @opaque t :: %{required(index) => byte} @type index :: non_neg_integer + @type new_error_reason :: + {:alphabet_is_not_an_utf8_string, term} + | {:alphabet_contains_multibyte_graphemes, [String.grapheme(), ...]} + | {:alphabet_is_too_small, [min_length: pos_integer, alphabet: String.t()]} + | {:alphabet_contains_repeated_graphemes, [String.grapheme(), ...]} + ## API - @spec new_shuffled(String.t()) :: {:ok, t} | {:error, term} + @spec new_shuffled(term) :: {:ok, t} | {:error, new_error_reason} def new_shuffled(alphabet_str) do - with :ok <- validate_alphabet_graphemes_are_not_multibyte(alphabet_str), + with :ok <- validate_alphabet_is_utf8_string(alphabet_str), + :ok <- validate_alphabet_graphemes_are_not_multibyte(alphabet_str), :ok <- validate_alphabet_length(alphabet_str), :ok <- validate_alphabet_has_unique_chars(alphabet_str) do alphabet = alphabet_str |> new_from_valid_str!() |> shuffle() @@ -110,6 +117,14 @@ defmodule Sqids.Alphabet do ## Internal + defp validate_alphabet_is_utf8_string(alphabet_str) do + if is_binary(alphabet_str) and String.valid?(alphabet_str) do + :ok + else + {:error, {:alphabet_is_not_an_utf8_string, alphabet_str}} + end + end + defp validate_alphabet_graphemes_are_not_multibyte(alphabet_str) do alphabet_graphemes = String.graphemes(alphabet_str) diff --git a/lib/sqids/blocklist.ex b/lib/sqids/blocklist.ex index b62f8c2..d811e5f 100644 --- a/lib/sqids/blocklist.ex +++ b/lib/sqids/blocklist.ex @@ -14,9 +14,13 @@ defmodule Sqids.Blocklist do matches_anywhere: [String.t()] } + @type new_error_reason :: + {:blocklist_is_not_enumerable, term} + | {:some_words_in_blocklist_are_not_utf8_strings, [term, ...]} + ## API Functions - @spec new(term, non_neg_integer, String.t()) :: {:ok, t()} | {:error, term} + @spec new(term, non_neg_integer, String.t()) :: {:ok, t()} | {:error, new_error_reason} def new(words, min_word_length, alphabet_str) do case validate_words(words) do :ok -> @@ -53,14 +57,14 @@ defmodule Sqids.Blocklist do defp validate_words(words) do Enum.filter(words, &(not is_binary(&1) or not String.valid?(&1))) catch - :error, _ -> - {:error, {:invalid_blocklist, words}} + :error, %Protocol.UndefinedError{value: ^words} -> + {:error, {:blocklist_is_not_enumerable, words}} else [] -> :ok invalid_words -> - {:error, {:invalid_words_in_blocklist, invalid_words}} + {:error, {:some_words_in_blocklist_are_not_utf8_strings, invalid_words}} end @spec new_for_valid_words(Sqids.enumerable(String.t()), non_neg_integer, String.t()) :: t() diff --git a/mix.exs b/mix.exs index 71e6ada..f892f91 100644 --- a/mix.exs +++ b/mix.exs @@ -23,7 +23,7 @@ defmodule Sqids.MixProject do ], test_coverage: [ summary: [ - threshold: 90 + threshold: 95 ] ], package: package() diff --git a/test/sqids_test.exs b/test/sqids_test.exs index d7c54f5..45a6d53 100644 --- a/test/sqids_test.exs +++ b/test/sqids_test.exs @@ -51,6 +51,10 @@ defmodule SqidsTest do call_instance_fun(instance, :decode!, [id]) end + def decode(instance, id) do + call_instance_fun(instance, :decode, [id]) + end + def assert_encode_and_back(sqids, numbers) do import ExUnit.Assertions assert decode!(sqids, encode!(sqids, numbers)) === numbers @@ -222,7 +226,9 @@ defmodule SqidsTest do {:ok, instance} = new_sqids(unquote(access_type), alphabet: alphabet, min_length: min_length, blocklist: blocklist) - assert encode(instance, [0]) === {:error, {:reached_max_attempts_to_regenerate_the_id, 3}} + input = [0] + assert encode(instance, input) === {:error, {:all_id_generation_attempts_were_censored, 3}} + assert_raise RuntimeError, "All id generation attempts were censored: 3", fn -> encode!(instance, input) end end end end @@ -457,14 +463,259 @@ defmodule SqidsTest do # for those langs that don't support `u8` test "#{access_type}: out-of-range invalid min length" do - assert new_sqids(unquote(access_type), min_length: -1) == - {:error, {:min_length_not_an_integer_in_range, -1, range: 0..255}} + assert_raise ArgumentError, "Min length is not an integer in range: [value: -1, range: 0..255]", fn -> + new_sqids(unquote(access_type), min_length: -1) + end + + assert_raise ArgumentError, "Min length is not an integer in range: [value: 256, range: 0..255]", fn -> + new_sqids(unquote(access_type), min_length: 256) + end + + assert_raise ArgumentError, "Min length is not an integer in range: [value: \"1\", range: 0..255]", fn -> + new_sqids(unquote(access_type), min_length: "1") + end + end + end + end + + defmodule AdditionalInstantiationScenarios do + @moduledoc false + use ExUnit.Case, async: true + + import SqidsTest.Shared + + for access_type <- [:"Direct API", :"Using module"] do + test "#{access_type}: new/2: alphabet is not an UTF-8 string" do + at = unquote(access_type) + + input = [3] + assert_raise ArgumentError, "Alphabet is not an utf8 string: [3]", fn -> new_sqids(at, alphabet: input) end + + input = ~c"abcdf" + + assert_raise ArgumentError, "Alphabet is not an utf8 string: ~c\"abcdf\"", fn -> + new_sqids(at, alphabet: input) + end + + input = <<128>> + assert_raise ArgumentError, "Alphabet is not an utf8 string: <<128>>", fn -> new_sqids(at, alphabet: input) end + end + + test "#{access_type}: new/2: blocklist is not enumerable" do + at = unquote(access_type) + + input = {"word"} + assert_raise ArgumentError, "Blocklist is not enumerable: {\"word\"}", fn -> new_sqids(at, blocklist: input) end + + input = 42.456 + assert_raise ArgumentError, "Blocklist is not enumerable: 42.456", fn -> new_sqids(at, blocklist: input) end + + input = "555" + assert_raise ArgumentError, "Blocklist is not enumerable: \"555\"", fn -> new_sqids(at, blocklist: input) end + end + + test "#{access_type}: new/2: some words in blocklist are not UTF-8 strings " do + at = unquote(access_type) + + input = ["aaaa", -44.3, "ok", 5, "go", <<128>>, <<129>>, "done"] + + assert_raise ArgumentError, "Some words in blocklist are not utf8 strings: [-44.3, 5, <<128>>, <<129>>]", fn -> + new_sqids(at, blocklist: input) + end + end + end + + test "Blocklist: short words are not blocked" do + alphabet_str = "abc" + {:ok, blocklist} = Sqids.Blocklist.new(["abc"], _min_word_length = 4, alphabet_str) + refute Sqids.Blocklist.is_blocked_id(blocklist, "abc") + end + + test "Stopped agent" do + assert_raise RuntimeError, ~r/Sqids shared state not found/, fn -> Sqids.Agent.get(RandomModule354343) end + end + end + + defmodule AdditionalEncodingScenarios do + @moduledoc false + use ExUnit.Case, async: true + + import SqidsTest.Shared + + test "encode/2: invalid sqids" do + sqids = :no + assert_raise ArgumentError, "argument error: :no", fn -> Sqids.encode(sqids, [33]) end + + sqids = %{a: 55} + assert_raise ArgumentError, "argument error: %{a: 55}", fn -> Sqids.encode(sqids, [33]) end + + sqids = %{__struct__: No} + assert_raise ArgumentError, "argument error: %{__struct__: No}", fn -> Sqids.encode(sqids, [33]) end + end + + test "encode!/2: invalid sqids" do + sqids = :no + assert_raise ArgumentError, "argument error: :no", fn -> Sqids.encode!(sqids, [33]) end + + sqids = %{a: 55} + assert_raise ArgumentError, "argument error: %{a: 55}", fn -> Sqids.encode!(sqids, [33]) end + + sqids = %{__struct__: No} + assert_raise ArgumentError, "argument error: %{__struct__: No}", fn -> Sqids.encode!(sqids, [33]) end + end + + for access_type <- [:"Direct API", :"Using module"] do + test "#{access_type}: encode/2: number is not a non negative integer" do + {:ok, instance} = new_sqids(unquote(access_type)) + + input = [-1] + assert_raise ArgumentError, "Number is not a non negative integer: -1", fn -> encode(instance, input) end + + input = [332, 43_543, -5, 23_434] + assert_raise ArgumentError, "Number is not a non negative integer: -5", fn -> encode(instance, input) end + + input = [332, 43_543, 23_434, 233, -10] + assert_raise ArgumentError, "Number is not a non negative integer: -10", fn -> encode(instance, input) end + + input = [55, "Oh no"] + assert_raise ArgumentError, "Number is not a non negative integer: \"Oh no\"", fn -> encode(instance, input) end + end + + test "#{access_type}: encode!/2: number is not a non negative integer" do + {:ok, instance} = new_sqids(unquote(access_type)) + + input = [-1] + assert_raise ArgumentError, "Number is not a non negative integer: -1", fn -> encode!(instance, input) end + + input = [332, 43_543, -5, 23_434] + assert_raise ArgumentError, "Number is not a non negative integer: -5", fn -> encode!(instance, input) end + + input = [332, 43_543, 23_434, 233, -10] + assert_raise ArgumentError, "Number is not a non negative integer: -10", fn -> encode!(instance, input) end + + input = [55, "Oh no"] + assert_raise ArgumentError, "Number is not a non negative integer: \"Oh no\"", fn -> encode!(instance, input) end + end + + test "#{access_type}: encode/2: numbers not enumerable" do + {:ok, instance} = new_sqids(unquote(access_type)) + + input = {55} + assert_raise ArgumentError, "Numbers not enumerable: {55}", fn -> encode(instance, input) end + + input = 3.5346 + assert_raise ArgumentError, "Numbers not enumerable: 3.5346", fn -> encode(instance, input) end + + input = "56" + assert_raise ArgumentError, "Numbers not enumerable: \"56\"", fn -> encode(instance, input) end + + input = :"42" + assert_raise ArgumentError, "Numbers not enumerable: :\"42\"", fn -> encode(instance, input) end + end + + test "#{access_type}: encode!/2: numbers not enumerable" do + {:ok, instance} = new_sqids(unquote(access_type)) + + input = {55} + assert_raise ArgumentError, "Numbers not enumerable: {55}", fn -> encode!(instance, input) end + + input = 3.5346 + assert_raise ArgumentError, "Numbers not enumerable: 3.5346", fn -> encode!(instance, input) end + + input = "56" + assert_raise ArgumentError, "Numbers not enumerable: \"56\"", fn -> encode!(instance, input) end + + input = :"42" + assert_raise ArgumentError, "Numbers not enumerable: :\"42\"", fn -> encode!(instance, input) end + end + end + end + + defmodule AdditionalDecodingScenarios do + @moduledoc false + use ExUnit.Case, async: true + + import SqidsTest.Shared + + test "decode/2: invalid sqids" do + sqids = :no + assert_raise ArgumentError, "argument error: :no", fn -> Sqids.decode(sqids, "0") end + + sqids = %{a: 55} + assert_raise ArgumentError, "argument error: %{a: 55}", fn -> Sqids.decode(sqids, "0") end + + sqids = %{__struct__: No} + assert_raise ArgumentError, "argument error: %{__struct__: No}", fn -> Sqids.decode(sqids, "0") end + end + + test "decode!/2: invalid sqids" do + sqids = :no + assert_raise ArgumentError, "argument error: :no", fn -> Sqids.decode!(sqids, "0") end + + sqids = %{a: 55} + assert_raise ArgumentError, "argument error: %{a: 55}", fn -> Sqids.decode!(sqids, "0") end + + sqids = %{__struct__: No} + assert_raise ArgumentError, "argument error: %{__struct__: No}", fn -> Sqids.decode!(sqids, "0") end + end + + for access_type <- [:"Direct API", :"Using module"] do + test "#{access_type}: decode/2: id is not a string or valid UTF-8" do + {:ok, instance} = new_sqids(unquote(access_type)) + + input = ~c"555" + assert_raise ArgumentError, "Id is not a string: ~c\"555\"", fn -> decode(instance, input) end + + input = 10_432_345 + assert_raise ArgumentError, "Id is not a string: 10432345", fn -> decode(instance, input) end + + input = "0000" <> <<128>> + assert_raise ArgumentError, "Id is not utf8: <<48, 48, 48, 48, 128>>", fn -> decode(instance, input) end + end + + test "#{access_type}: decode!/2: id is not a string or valid UTF-8" do + {:ok, instance} = new_sqids(unquote(access_type)) + + input = ~c"555" + assert_raise ArgumentError, "Id is not a string: ~c\"555\"", fn -> decode!(instance, input) end + + input = 10_432_345 + assert_raise ArgumentError, "Id is not a string: 10432345", fn -> decode!(instance, input) end + + input = "0000" <> <<128>> + assert_raise ArgumentError, "Id is not utf8: <<48, 48, 48, 48, 128>>", fn -> decode!(instance, input) end + end + + test "#{access_type}: decode/2: id has unknown chars" do + {:ok, instance} = new_sqids(unquote(access_type), alphabet: "01234") + + input = "5" + assert decode(instance, input) === {:ok, []} + + input = "011015" + assert decode(instance, input) === {:ok, []} + + input = "011015143" + assert decode(instance, input) === {:ok, []} + + input = "000ë5" + assert decode(instance, input) === {:ok, []} + end + + test "#{access_type}: decode!/2: id has unknown chars" do + {:ok, instance} = new_sqids(unquote(access_type), alphabet: "01234") + + input = "5" + assert decode!(instance, input) === [] + + input = "011015" + assert decode!(instance, input) === [] - assert new_sqids(unquote(access_type), min_length: 256) == - {:error, {:min_length_not_an_integer_in_range, 256, range: 0..255}} + input = "011015143" + assert decode!(instance, input) === [] - assert new_sqids(unquote(access_type), min_length: "1") == - {:error, {:min_length_not_an_integer_in_range, "1", range: 0..255}} + input = "000ë5" + assert decode!(instance, input) === [] end end end