diff --git a/lib/sqids.ex b/lib/sqids.ex index 356c410..8b825bb 100644 --- a/lib/sqids.ex +++ b/lib/sqids.ex @@ -56,6 +56,15 @@ defmodule Sqids do blocklist: blocklist }} else + {:error, {tag, _} = reason} + when tag in [ + :alphabet_is_not_an_utf8_string, + :min_length_is_not_an_integer_in_range, + :blocklist_is_not_enumerable, + :some_words_in_blocklist_are_not_utf8_strings + ] -> + raise %ArgumentError{message: error_reason_to_string(reason)} + {:error, _} = error -> error end @@ -63,8 +72,13 @@ defmodule Sqids do @spec encode!(t(), enumerable(non_neg_integer)) :: String.t() def encode!(sqids, numbers) do - {:ok, string} = encode(sqids, numbers) - string + case encode(sqids, numbers) do + {:ok, string} -> + string + + {:error, {:all_id_generation_attempts_were_censored, _nr_of_attempts} = reason} -> + raise error_reason_to_string(reason) + end end @spec encode(t(), enumerable(non_neg_integer)) :: {:ok, String.t()} | {:error, term} @@ -73,18 +87,26 @@ defmodule Sqids do {:ok, numbers_list} -> encode_numbers(sqids, numbers_list) - {:error, _} = error -> - error + {:error, reason} -> + raise %ArgumentError{message: error_reason_to_string(reason)} end end + def encode(sqids, _numbers), do: :erlang.error({:badarg, sqids}) + @spec decode!(t(), String.t()) :: [non_neg_integer] def decode!(sqids, id) do - {:ok, numbers} = decode(sqids, id) - numbers + case decode(sqids, id) do + {:ok, numbers} -> + numbers + + # {:error, reason} -> + # raise error_reason_to_string(reason) + end end - @spec decode(t(), String.t()) :: {:ok, [non_neg_integer]} | {:error, term} + # | {:error, term} + @spec decode(t(), String.t()) :: {:ok, [non_neg_integer]} def decode(%Sqids{} = sqids, id) do case validate_id(sqids, id) do :ok -> @@ -98,11 +120,13 @@ defmodule Sqids do # Follow the spec's behaviour and return an empty list {:ok, []} - {:error, _} = error -> - error + {:error, {tag, _} = reason} when tag in [:id_is_not_utf8, :id_is_not_a_string] -> + raise %ArgumentError{message: error_reason_to_string(reason)} end end + def decode(sqids, _id), do: :erlang.error({:badarg, sqids}) + ## Internal Functions @doc false @@ -110,7 +134,7 @@ defmodule Sqids do defp validate_min_length(min_length) do if not is_integer(min_length) or min_length not in @min_length_range do - {:error, {:min_length_not_an_integer_in_range, min_length, range: @min_length_range}} + {:error, {:min_length_is_not_an_integer_in_range, value: min_length, range: @min_length_range}} else :ok end @@ -128,16 +152,17 @@ defmodule Sqids do end defp validate_numbers(numbers) do - numbers - |> Enum.find(&(not is_valid_number(&1))) - |> case do - nil -> - numbers_list = Enum.to_list(numbers) - {:ok, numbers_list} - - invalid_number -> - {:error, {:number_must_be_a_non_negative_integer, invalid_number}} - end + Enum.find(numbers, &(not is_valid_number(&1))) + catch + :error, %Protocol.UndefinedError{value: ^numbers} -> + {:error, {:numbers_not_enumerable, numbers}} + else + nil -> + numbers_list = Enum.to_list(numbers) + {:ok, numbers_list} + + invalid_number -> + {:error, {:number_is_not_a_non_negative_integer, invalid_number}} end defp is_valid_number(number), do: is_integer(number) and number >= 0 @@ -155,7 +180,7 @@ defmodule Sqids do defp attempt_to_encode_numbers(sqids, list, attempt_index) do if attempt_index > Alphabet.size(sqids.alphabet) do # We've reached max attempts - {:error, {:reached_max_attempts_to_regenerate_the_id, attempt_index - 1}} + {:error, {:all_id_generation_attempts_were_censored, _nr_of_attempts = attempt_index - 1}} else do_attempt_to_encode_numbers(sqids, list, attempt_index) end @@ -295,15 +320,20 @@ defmodule Sqids do defp validate_id(_sqids, ""), do: :empty_id defp validate_id(sqids, id) when is_binary(id) do - if are_all_chars_in_id_known(id, sqids.alphabet) do - :ok - else - :unknown_chars_in_id + case String.valid?(id) and {:all_chars_known, are_all_chars_in_id_known(id, sqids.alphabet)} do + {:all_chars_known, true} -> + :ok + + {:all_chars_known, false} -> + :unknown_chars_in_id + + false -> + {:error, {:id_is_not_utf8, id}} end end defp validate_id(_sqids, not_a_string) do - {:error, {:id_not_a_string, not_a_string}} + {:error, {:id_is_not_a_string, not_a_string}} end defp are_all_chars_in_id_known(id, alphabet) do @@ -380,6 +410,16 @@ defmodule Sqids do {:ok, numbers} end + defp error_reason_to_string({tag, details}) when is_atom(tag) do + "#{prettify_error_tag(tag)}: #{inspect(details)}" + end + + defp prettify_error_tag(tag) do + [first_word | next_words] = tag |> Atom.to_string() |> String.split("_") + first_word = String.capitalize(first_word) + Enum.join([first_word | next_words], " ") + end + ## Code generation defmacro __using__([]) do diff --git a/lib/sqids/agent.ex b/lib/sqids/agent.ex index 1974dd4..3de94f5 100644 --- a/lib/sqids/agent.ex +++ b/lib/sqids/agent.ex @@ -1,7 +1,7 @@ defmodule Sqids.Agent do @moduledoc """ - Storage for `Sqids` shared state. - Like stdlib's [Agent](https://hexdocs.pm/elixir/1.15/Agent.html) but using + Storage for `Sqids` shared state. + Like stdlib's [Agent](https://hexdocs.pm/elixir/1.15/Agent.html) but using OTP's [`persistent_term`](https://www.erlang.org/doc/man/persistent_term). """ @@ -32,8 +32,16 @@ defmodule Sqids.Agent do shared_state_init: shared_state_init ] - server_name = server_name(sqids_module) - GenServer.start_link(__MODULE__, init_args, name: server_name) + case :proc_lib.start_link(__MODULE__, :proc_lib_init, [init_args]) do + {:ok, _} = success -> + success + + {:error, _} = error -> + error + + {:intentional_raise, reason, stacktrace} -> + :erlang.raise(:error, reason, stacktrace) + end end @doc false @@ -55,27 +63,29 @@ defmodule Sqids.Agent do ## GenServer callbacks @doc false - @impl true - @spec init(init_args) :: {:ok, state} | no_return() - def init(init_args) do + @spec proc_lib_init(init_args) :: no_return() + def proc_lib_init(init_args) do sqids_module = Keyword.fetch!(init_args, :sqids_module) - {shared_state_init_fun, shared_state_args} = Keyword.fetch!(init_args, :shared_state_init) - - case apply(shared_state_init_fun, shared_state_args) do - {:ok, shared_state} -> - # Ensure `:terminate/2` gets called unless we're killed - _ = Process.flag(:trap_exit, true) - - shared_state_key = shared_state_key(sqids_module) - :persistent_term.put(shared_state_key, shared_state) - state = state(shared_state_key: shared_state_key) - {:ok, state} + server_name = server_name(sqids_module) - {:error, _} = error -> - init_fail(error, sqids_module) + try do + Process.register(self(), server_name) + catch + :error, %ArgumentError{} when is_atom(server_name) -> + init_fail({:error, {:already_started, Process.whereis(server_name)}}, server_name) + else + true -> + proc_lib_init_registered(init_args, sqids_module, server_name) end end + @doc false + @impl true + @spec init(term) :: no_return() + def init(_init_args) do + raise "Initialization is done through :proc_lib_init/1" + end + @doc false @impl true @spec terminate(term, state) :: term @@ -95,7 +105,39 @@ defmodule Sqids.Agent do ## Internal - defp init_fail(error, sqids_module) do + defp proc_lib_init_registered(init_args, sqids_module, server_name) do + {shared_state_init_fun, shared_state_args} = Keyword.fetch!(init_args, :shared_state_init) + + try do + apply(shared_state_init_fun, shared_state_args) + catch + :error, %ArgumentError{} = reason -> + stacktrace = __STACKTRACE__ + init_fail({:intentional_raise, reason, stacktrace}, server_name) + else + {:ok, shared_state} -> + # Ensure `:terminate/2` gets called unless we're killed + _ = Process.flag(:trap_exit, true) + + shared_state_key = shared_state_key(sqids_module) + :persistent_term.put(shared_state_key, shared_state) + state = state(shared_state_key: shared_state_key) + :proc_lib.init_ack({:ok, self()}) + + :gen_server.enter_loop( + __MODULE__, + _enter_loop_opts = [], + state, + {:local, server_name}, + :hibernate + ) + + {:error, _} = error -> + init_fail(error, server_name) + end + end + + defp init_fail(error, server_name) do # Use proc_lib:init_fail/2 instead of {:stop, reason} to avoid # polluting the logs: our supervisor will fail to start us and this # will already produce log messages with the relevant info. @@ -106,7 +148,6 @@ defmodule Sqids.Agent do catch :error, :undef -> # Fallback for OTP 25 or older - server_name = server_name(sqids_module) Process.unregister(server_name) :proc_lib.init_ack(error) :erlang.exit(:normal) diff --git a/lib/sqids/alphabet.ex b/lib/sqids/alphabet.ex index f3308e0..ac6517b 100644 --- a/lib/sqids/alphabet.ex +++ b/lib/sqids/alphabet.ex @@ -9,11 +9,18 @@ defmodule Sqids.Alphabet do @opaque t :: %{required(index) => byte} @type index :: non_neg_integer + @type new_error_reason :: + {:alphabet_is_not_an_utf8_string, term} + | {:alphabet_contains_multibyte_graphemes, [String.grapheme(), ...]} + | {:alphabet_is_too_small, [min_length: pos_integer, alphabet: String.t()]} + | {:alphabet_contains_repeated_graphemes, [String.grapheme(), ...]} + ## API - @spec new_shuffled(String.t()) :: {:ok, t} | {:error, term} + @spec new_shuffled(term) :: {:ok, t} | {:error, new_error_reason} def new_shuffled(alphabet_str) do - with :ok <- validate_alphabet_graphemes_are_not_multibyte(alphabet_str), + with :ok <- validate_alphabet_is_utf8_string(alphabet_str), + :ok <- validate_alphabet_graphemes_are_not_multibyte(alphabet_str), :ok <- validate_alphabet_length(alphabet_str), :ok <- validate_alphabet_has_unique_chars(alphabet_str) do alphabet = alphabet_str |> new_from_valid_str!() |> shuffle() @@ -110,6 +117,14 @@ defmodule Sqids.Alphabet do ## Internal + defp validate_alphabet_is_utf8_string(alphabet_str) do + if is_binary(alphabet_str) and String.valid?(alphabet_str) do + :ok + else + {:error, {:alphabet_is_not_an_utf8_string, alphabet_str}} + end + end + defp validate_alphabet_graphemes_are_not_multibyte(alphabet_str) do alphabet_graphemes = String.graphemes(alphabet_str) diff --git a/lib/sqids/blocklist.ex b/lib/sqids/blocklist.ex index b62f8c2..d811e5f 100644 --- a/lib/sqids/blocklist.ex +++ b/lib/sqids/blocklist.ex @@ -14,9 +14,13 @@ defmodule Sqids.Blocklist do matches_anywhere: [String.t()] } + @type new_error_reason :: + {:blocklist_is_not_enumerable, term} + | {:some_words_in_blocklist_are_not_utf8_strings, [term, ...]} + ## API Functions - @spec new(term, non_neg_integer, String.t()) :: {:ok, t()} | {:error, term} + @spec new(term, non_neg_integer, String.t()) :: {:ok, t()} | {:error, new_error_reason} def new(words, min_word_length, alphabet_str) do case validate_words(words) do :ok -> @@ -53,14 +57,14 @@ defmodule Sqids.Blocklist do defp validate_words(words) do Enum.filter(words, &(not is_binary(&1) or not String.valid?(&1))) catch - :error, _ -> - {:error, {:invalid_blocklist, words}} + :error, %Protocol.UndefinedError{value: ^words} -> + {:error, {:blocklist_is_not_enumerable, words}} else [] -> :ok invalid_words -> - {:error, {:invalid_words_in_blocklist, invalid_words}} + {:error, {:some_words_in_blocklist_are_not_utf8_strings, invalid_words}} end @spec new_for_valid_words(Sqids.enumerable(String.t()), non_neg_integer, String.t()) :: t() diff --git a/mix.exs b/mix.exs index 71e6ada..f892f91 100644 --- a/mix.exs +++ b/mix.exs @@ -23,7 +23,7 @@ defmodule Sqids.MixProject do ], test_coverage: [ summary: [ - threshold: 90 + threshold: 95 ] ], package: package() diff --git a/test/sqids_test.exs b/test/sqids_test.exs index d7c54f5..45a6d53 100644 --- a/test/sqids_test.exs +++ b/test/sqids_test.exs @@ -51,6 +51,10 @@ defmodule SqidsTest do call_instance_fun(instance, :decode!, [id]) end + def decode(instance, id) do + call_instance_fun(instance, :decode, [id]) + end + def assert_encode_and_back(sqids, numbers) do import ExUnit.Assertions assert decode!(sqids, encode!(sqids, numbers)) === numbers @@ -222,7 +226,9 @@ defmodule SqidsTest do {:ok, instance} = new_sqids(unquote(access_type), alphabet: alphabet, min_length: min_length, blocklist: blocklist) - assert encode(instance, [0]) === {:error, {:reached_max_attempts_to_regenerate_the_id, 3}} + input = [0] + assert encode(instance, input) === {:error, {:all_id_generation_attempts_were_censored, 3}} + assert_raise RuntimeError, "All id generation attempts were censored: 3", fn -> encode!(instance, input) end end end end @@ -457,14 +463,259 @@ defmodule SqidsTest do # for those langs that don't support `u8` test "#{access_type}: out-of-range invalid min length" do - assert new_sqids(unquote(access_type), min_length: -1) == - {:error, {:min_length_not_an_integer_in_range, -1, range: 0..255}} + assert_raise ArgumentError, "Min length is not an integer in range: [value: -1, range: 0..255]", fn -> + new_sqids(unquote(access_type), min_length: -1) + end + + assert_raise ArgumentError, "Min length is not an integer in range: [value: 256, range: 0..255]", fn -> + new_sqids(unquote(access_type), min_length: 256) + end + + assert_raise ArgumentError, "Min length is not an integer in range: [value: \"1\", range: 0..255]", fn -> + new_sqids(unquote(access_type), min_length: "1") + end + end + end + end + + defmodule AdditionalInstantiationScenarios do + @moduledoc false + use ExUnit.Case, async: true + + import SqidsTest.Shared + + for access_type <- [:"Direct API", :"Using module"] do + test "#{access_type}: new/2: alphabet is not an UTF-8 string" do + at = unquote(access_type) + + input = [3] + assert_raise ArgumentError, "Alphabet is not an utf8 string: [3]", fn -> new_sqids(at, alphabet: input) end + + input = ~c"abcdf" + + assert_raise ArgumentError, "Alphabet is not an utf8 string: ~c\"abcdf\"", fn -> + new_sqids(at, alphabet: input) + end + + input = <<128>> + assert_raise ArgumentError, "Alphabet is not an utf8 string: <<128>>", fn -> new_sqids(at, alphabet: input) end + end + + test "#{access_type}: new/2: blocklist is not enumerable" do + at = unquote(access_type) + + input = {"word"} + assert_raise ArgumentError, "Blocklist is not enumerable: {\"word\"}", fn -> new_sqids(at, blocklist: input) end + + input = 42.456 + assert_raise ArgumentError, "Blocklist is not enumerable: 42.456", fn -> new_sqids(at, blocklist: input) end + + input = "555" + assert_raise ArgumentError, "Blocklist is not enumerable: \"555\"", fn -> new_sqids(at, blocklist: input) end + end + + test "#{access_type}: new/2: some words in blocklist are not UTF-8 strings " do + at = unquote(access_type) + + input = ["aaaa", -44.3, "ok", 5, "go", <<128>>, <<129>>, "done"] + + assert_raise ArgumentError, "Some words in blocklist are not utf8 strings: [-44.3, 5, <<128>>, <<129>>]", fn -> + new_sqids(at, blocklist: input) + end + end + end + + test "Blocklist: short words are not blocked" do + alphabet_str = "abc" + {:ok, blocklist} = Sqids.Blocklist.new(["abc"], _min_word_length = 4, alphabet_str) + refute Sqids.Blocklist.is_blocked_id(blocklist, "abc") + end + + test "Stopped agent" do + assert_raise RuntimeError, ~r/Sqids shared state not found/, fn -> Sqids.Agent.get(RandomModule354343) end + end + end + + defmodule AdditionalEncodingScenarios do + @moduledoc false + use ExUnit.Case, async: true + + import SqidsTest.Shared + + test "encode/2: invalid sqids" do + sqids = :no + assert_raise ArgumentError, "argument error: :no", fn -> Sqids.encode(sqids, [33]) end + + sqids = %{a: 55} + assert_raise ArgumentError, "argument error: %{a: 55}", fn -> Sqids.encode(sqids, [33]) end + + sqids = %{__struct__: No} + assert_raise ArgumentError, "argument error: %{__struct__: No}", fn -> Sqids.encode(sqids, [33]) end + end + + test "encode!/2: invalid sqids" do + sqids = :no + assert_raise ArgumentError, "argument error: :no", fn -> Sqids.encode!(sqids, [33]) end + + sqids = %{a: 55} + assert_raise ArgumentError, "argument error: %{a: 55}", fn -> Sqids.encode!(sqids, [33]) end + + sqids = %{__struct__: No} + assert_raise ArgumentError, "argument error: %{__struct__: No}", fn -> Sqids.encode!(sqids, [33]) end + end + + for access_type <- [:"Direct API", :"Using module"] do + test "#{access_type}: encode/2: number is not a non negative integer" do + {:ok, instance} = new_sqids(unquote(access_type)) + + input = [-1] + assert_raise ArgumentError, "Number is not a non negative integer: -1", fn -> encode(instance, input) end + + input = [332, 43_543, -5, 23_434] + assert_raise ArgumentError, "Number is not a non negative integer: -5", fn -> encode(instance, input) end + + input = [332, 43_543, 23_434, 233, -10] + assert_raise ArgumentError, "Number is not a non negative integer: -10", fn -> encode(instance, input) end + + input = [55, "Oh no"] + assert_raise ArgumentError, "Number is not a non negative integer: \"Oh no\"", fn -> encode(instance, input) end + end + + test "#{access_type}: encode!/2: number is not a non negative integer" do + {:ok, instance} = new_sqids(unquote(access_type)) + + input = [-1] + assert_raise ArgumentError, "Number is not a non negative integer: -1", fn -> encode!(instance, input) end + + input = [332, 43_543, -5, 23_434] + assert_raise ArgumentError, "Number is not a non negative integer: -5", fn -> encode!(instance, input) end + + input = [332, 43_543, 23_434, 233, -10] + assert_raise ArgumentError, "Number is not a non negative integer: -10", fn -> encode!(instance, input) end + + input = [55, "Oh no"] + assert_raise ArgumentError, "Number is not a non negative integer: \"Oh no\"", fn -> encode!(instance, input) end + end + + test "#{access_type}: encode/2: numbers not enumerable" do + {:ok, instance} = new_sqids(unquote(access_type)) + + input = {55} + assert_raise ArgumentError, "Numbers not enumerable: {55}", fn -> encode(instance, input) end + + input = 3.5346 + assert_raise ArgumentError, "Numbers not enumerable: 3.5346", fn -> encode(instance, input) end + + input = "56" + assert_raise ArgumentError, "Numbers not enumerable: \"56\"", fn -> encode(instance, input) end + + input = :"42" + assert_raise ArgumentError, "Numbers not enumerable: :\"42\"", fn -> encode(instance, input) end + end + + test "#{access_type}: encode!/2: numbers not enumerable" do + {:ok, instance} = new_sqids(unquote(access_type)) + + input = {55} + assert_raise ArgumentError, "Numbers not enumerable: {55}", fn -> encode!(instance, input) end + + input = 3.5346 + assert_raise ArgumentError, "Numbers not enumerable: 3.5346", fn -> encode!(instance, input) end + + input = "56" + assert_raise ArgumentError, "Numbers not enumerable: \"56\"", fn -> encode!(instance, input) end + + input = :"42" + assert_raise ArgumentError, "Numbers not enumerable: :\"42\"", fn -> encode!(instance, input) end + end + end + end + + defmodule AdditionalDecodingScenarios do + @moduledoc false + use ExUnit.Case, async: true + + import SqidsTest.Shared + + test "decode/2: invalid sqids" do + sqids = :no + assert_raise ArgumentError, "argument error: :no", fn -> Sqids.decode(sqids, "0") end + + sqids = %{a: 55} + assert_raise ArgumentError, "argument error: %{a: 55}", fn -> Sqids.decode(sqids, "0") end + + sqids = %{__struct__: No} + assert_raise ArgumentError, "argument error: %{__struct__: No}", fn -> Sqids.decode(sqids, "0") end + end + + test "decode!/2: invalid sqids" do + sqids = :no + assert_raise ArgumentError, "argument error: :no", fn -> Sqids.decode!(sqids, "0") end + + sqids = %{a: 55} + assert_raise ArgumentError, "argument error: %{a: 55}", fn -> Sqids.decode!(sqids, "0") end + + sqids = %{__struct__: No} + assert_raise ArgumentError, "argument error: %{__struct__: No}", fn -> Sqids.decode!(sqids, "0") end + end + + for access_type <- [:"Direct API", :"Using module"] do + test "#{access_type}: decode/2: id is not a string or valid UTF-8" do + {:ok, instance} = new_sqids(unquote(access_type)) + + input = ~c"555" + assert_raise ArgumentError, "Id is not a string: ~c\"555\"", fn -> decode(instance, input) end + + input = 10_432_345 + assert_raise ArgumentError, "Id is not a string: 10432345", fn -> decode(instance, input) end + + input = "0000" <> <<128>> + assert_raise ArgumentError, "Id is not utf8: <<48, 48, 48, 48, 128>>", fn -> decode(instance, input) end + end + + test "#{access_type}: decode!/2: id is not a string or valid UTF-8" do + {:ok, instance} = new_sqids(unquote(access_type)) + + input = ~c"555" + assert_raise ArgumentError, "Id is not a string: ~c\"555\"", fn -> decode!(instance, input) end + + input = 10_432_345 + assert_raise ArgumentError, "Id is not a string: 10432345", fn -> decode!(instance, input) end + + input = "0000" <> <<128>> + assert_raise ArgumentError, "Id is not utf8: <<48, 48, 48, 48, 128>>", fn -> decode!(instance, input) end + end + + test "#{access_type}: decode/2: id has unknown chars" do + {:ok, instance} = new_sqids(unquote(access_type), alphabet: "01234") + + input = "5" + assert decode(instance, input) === {:ok, []} + + input = "011015" + assert decode(instance, input) === {:ok, []} + + input = "011015143" + assert decode(instance, input) === {:ok, []} + + input = "000ë5" + assert decode(instance, input) === {:ok, []} + end + + test "#{access_type}: decode!/2: id has unknown chars" do + {:ok, instance} = new_sqids(unquote(access_type), alphabet: "01234") + + input = "5" + assert decode!(instance, input) === [] + + input = "011015" + assert decode!(instance, input) === [] - assert new_sqids(unquote(access_type), min_length: 256) == - {:error, {:min_length_not_an_integer_in_range, 256, range: 0..255}} + input = "011015143" + assert decode!(instance, input) === [] - assert new_sqids(unquote(access_type), min_length: "1") == - {:error, {:min_length_not_an_integer_in_range, "1", range: 0..255}} + input = "000ë5" + assert decode!(instance, input) === [] end end end