diff --git a/internal/smtp/http.go b/internal/smtp/http.go index 84a64f1..d610265 100644 --- a/internal/smtp/http.go +++ b/internal/smtp/http.go @@ -4,6 +4,7 @@ import ( "fmt" "net/http" "path" + "regexp" "strconv" "strings" @@ -90,9 +91,19 @@ func (h *smtpHTTPHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { } func (h *smtpHTTPHandler) handleMessageIndex(w http.ResponseWriter, r *http.Request) { - // TODO: Implement search function + var receivedMessages []*ReceivedMessage - receivedMessages := h.server.ReceivedMessages() + headerSearchRgx, err := extractSearchRegex(w, r.URL.Query(), "header") + if err != nil { + handlerutil.RespondWithErr(w, http.StatusBadRequest, err) + return + } + if headerSearchRgx == nil { + receivedMessages = h.server.ReceivedMessages() + } else { + // FIXME: This does not preserve the correct indexes! + receivedMessages = h.server.SearchByHeader(headerSearchRgx) + } messagesOut := make([]any, 0) @@ -141,7 +152,6 @@ func (h *smtpHTTPHandler) handleMessageBody(w http.ResponseWriter, r *http.Reque } func (h *smtpHTTPHandler) handleMultipartIndex(w http.ResponseWriter, r *http.Request, idx int) { - // TODO: Implement search function msg := h.retrieveMessage(w, idx) if msg == nil { return @@ -150,7 +160,18 @@ func (h *smtpHTTPHandler) handleMultipartIndex(w http.ResponseWriter, r *http.Re return } - multiparts := msg.Multiparts() + var multiparts []*ReceivedPart + headerSearchRgx, err := extractSearchRegex(w, r.URL.Query(), "header") + if err != nil { + handlerutil.RespondWithErr(w, http.StatusBadRequest, err) + return + } + if headerSearchRgx == nil { + multiparts = msg.Multiparts() + } else { + // FIXME: This does not preserve the correct indexes! + multiparts = msg.SearchPartsByHeader(headerSearchRgx) + } multipartsOut := make([]any, 0) @@ -271,3 +292,30 @@ func ensureIsMultipart(w http.ResponseWriter, msg *ReceivedMessage) bool { return false } + +// extractSearchRegex tries to extract a regular expression from the referenced +// query parameter. If no query parameter is given and otherwise no error has +// occurred, this function returns (nil, nil). +func extractSearchRegex( + w http.ResponseWriter, queryParams map[string][]string, paramName string, +) (*regexp.Regexp, error) { + searchParam, ok := queryParams[paramName] + if ok { + if len(searchParam) != 1 { + return nil, fmt.Errorf( + "Encountered multiple %q params", paramName, + ) + } + + re, err := regexp.Compile(searchParam[0]) + if err != nil { + return nil, fmt.Errorf( + "could not compile %q regex: %w", paramName, err, + ) + } + + return re, nil + } + + return nil, nil +} diff --git a/internal/smtp/message.go b/internal/smtp/message.go index 2223dc8..08ce2bc 100644 --- a/internal/smtp/message.go +++ b/internal/smtp/message.go @@ -9,6 +9,7 @@ import ( "mime/multipart" "net/mail" "net/textproto" + "regexp" "strings" "time" ) @@ -113,6 +114,37 @@ func NewReceivedMessage( return msg, nil } +// SearchPartsByHeader returns the list of all received multiparts that +// have at least one header matching the given regular expression. +// +// For details on how the matching is performed, please refer to the +// documentation for Server.SearchByHeader. +// +// If the message is not a multipart message, this returns nil. +// If no matching multiparts are found, this may return nil or an empty +// list. +func (m *ReceivedMessage) SearchPartsByHeader(re *regexp.Regexp) []*ReceivedPart { + if !m.IsMultipart() { + return nil + } + + multiparts := m.Multiparts() + + headerIdxList := make([]map[string][]string, len(multiparts)) + for i, v := range multiparts { + headerIdxList[i] = v.Headers() + } + + foundIndices := searchByHeaderCommon(headerIdxList, re) + + results := make([]*ReceivedPart, 0, len(foundIndices)) + for _, idx := range foundIndices { + results = append(results, multiparts[idx]) + } + + return results +} + // NewReceivedPart parses a MIME multipart part into a ReceivedPart struct. // // Incoming data is truncated after the given maximum message size. diff --git a/internal/smtp/search.go b/internal/smtp/search.go new file mode 100644 index 0000000..ed1ba32 --- /dev/null +++ b/internal/smtp/search.go @@ -0,0 +1,29 @@ +package smtp + +import ( + "fmt" + "regexp" +) + +func searchByHeaderCommon(headerIdxList []map[string][]string, re *regexp.Regexp) []int { + result := make([]int, 0, len(headerIdxList)) + + for idx, headers := range headerIdxList { + if anyHeaderMatches(headers, re) { + result = append(result, idx) + } + } + + return result +} + +func anyHeaderMatches(headers map[string][]string, re *regexp.Regexp) bool { + for k, v := range headers { + header := fmt.Sprintf("%s: %s", k, v) + if re.MatchString(header) { + return true + } + } + + return false +} diff --git a/internal/smtp/server.go b/internal/smtp/server.go index 5684f88..b041849 100644 --- a/internal/smtp/server.go +++ b/internal/smtp/server.go @@ -109,12 +109,31 @@ func (s *Server) ReceivedMessages() []*ReceivedMessage { // SearchByHeader returns the list of all received messages that have at // least one header matching the given regular expression. -func (s *Server) SearchByHeader(re *regexp.Regexp) []ReceivedMessage { +// +// Note that the regex is performed for each header value individually, +// including for multi-value headers. The header value is first serialized +// by concatenating it after the header name, colon and space. It is not +// being encoded as if for transport (e.g. quoted-printable), +// but concatenated as-is. +func (s *Server) SearchByHeader(re *regexp.Regexp) []*ReceivedMessage { s.mutex.RLock() defer s.mutex.RUnlock() - // TODO - panic("not implemented") + receivedMessages := s.ReceivedMessages() + + headerIdxList := make([]map[string][]string, len(receivedMessages)) + for i, v := range receivedMessages { + headerIdxList[i] = v.Headers() + } + + foundIndices := searchByHeaderCommon(headerIdxList, re) + + results := make([]*ReceivedMessage, 0, len(foundIndices)) + for _, idx := range foundIndices { + results = append(results, receivedMessages[idx]) + } + + return results } func newSession(server *Server, c *smtp.Conn) (smtp.Session, error) {