Skip to content

Commit

Permalink
feat: automatically quarantine project via task (pypi#17412)
Browse files Browse the repository at this point in the history
  • Loading branch information
miketheman authored Jan 15, 2025
1 parent 7c2bf39 commit f46c35f
Show file tree
Hide file tree
Showing 11 changed files with 501 additions and 21 deletions.
3 changes: 3 additions & 0 deletions dev/environment
Original file line number Diff line number Diff line change
Expand Up @@ -84,3 +84,6 @@ HCAPTCHA_SECRET_KEY=0x0000000000000000000000000000000000000000
# HELPSCOUT_WAREHOUSE_APP_SECRET="an insecure helpscout app secret"
# HELPSCOUT_WAREHOUSE_MAILBOX_ID=123456789
HELPDESK_BACKEND="warehouse.helpdesk.services.ConsoleHelpDeskService"

# HELPDESK_NOTIFICATION_SERVICE_URL="https://..."
HELPDESK_NOTIFICATION_BACKEND="warehouse.helpdesk.services.ConsoleAdminNotificationService"
15 changes: 14 additions & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@
from warehouse.email import services as email_services
from warehouse.email.interfaces import IEmailSender
from warehouse.helpdesk import services as helpdesk_services
from warehouse.helpdesk.interfaces import IHelpDeskService
from warehouse.helpdesk.interfaces import IAdminNotificationService, IHelpDeskService
from warehouse.macaroons import services as macaroon_services
from warehouse.macaroons.interfaces import IMacaroonService
from warehouse.metrics import IMetricsService
Expand Down Expand Up @@ -183,6 +183,7 @@ def pyramid_services(
integrity_service,
macaroon_service,
helpdesk_service,
notification_service,
):
services = _Services()

Expand All @@ -205,6 +206,7 @@ def pyramid_services(
services.register_service(integrity_service, IIntegrityService, None)
services.register_service(macaroon_service, IMacaroonService, None, name="")
services.register_service(helpdesk_service, IHelpDeskService, None)
services.register_service(notification_service, IAdminNotificationService)

return services

Expand All @@ -230,6 +232,11 @@ def pyramid_request(pyramid_services, jinja, remote_addr, remote_addr_hashed):
dummy_request.task = pretend.call_recorder(
lambda *a, **kw: dummy_request._task_stub
)
dummy_request.log = pretend.stub(
bind=pretend.call_recorder(lambda *args, **kwargs: dummy_request.log),
info=pretend.call_recorder(lambda *args, **kwargs: None),
error=pretend.call_recorder(lambda *args, **kwargs: None),
)

def localize(message, **kwargs):
ts = TranslationString(message, **kwargs)
Expand Down Expand Up @@ -339,6 +346,7 @@ def get_app_config(database, nondefaults=None):
"billing.api_version": "2020-08-27",
"mail.backend": "warehouse.email.services.SMTPEmailSender",
"helpdesk.backend": "warehouse.helpdesk.services.ConsoleHelpDeskService",
"helpdesk.notification_backend": "warehouse.helpdesk.services.ConsoleHelpDeskService", # noqa: E501
"files.url": "http://localhost:7000/",
"archive_files.url": "http://localhost:7000/archive",
"sessions.secret": "123456",
Expand Down Expand Up @@ -616,6 +624,11 @@ def helpdesk_service():
return helpdesk_services.ConsoleHelpDeskService()


@pytest.fixture
def notification_service():
return helpdesk_services.ConsoleAdminNotificationService()


class QueryRecorder:
def __init__(self):
self.queries = []
Expand Down
21 changes: 15 additions & 6 deletions tests/unit/helpdesk/test_init.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,20 +13,29 @@
import pretend

from warehouse.helpdesk import includeme
from warehouse.helpdesk.interfaces import IHelpDeskService
from warehouse.helpdesk.interfaces import IAdminNotificationService, IHelpDeskService


def test_includeme():
helpdesk_class = pretend.stub(create_service=pretend.stub())
dummy_klass = pretend.stub(create_service=pretend.stub())
config = pretend.stub(
registry=pretend.stub(settings={"helpdesk.backend": "tests.CustomBackend"}),
maybe_dotted=pretend.call_recorder(lambda n: helpdesk_class),
registry=pretend.stub(
settings={
"helpdesk.backend": "test.HelpDeskService",
"helpdesk.notification_backend": "test.NotificationService",
}
),
maybe_dotted=pretend.call_recorder(lambda n: dummy_klass),
register_service_factory=pretend.call_recorder(lambda s, i, **kw: None),
)

includeme(config)

assert config.maybe_dotted.calls == [pretend.call("tests.CustomBackend")]
assert config.maybe_dotted.calls == [
pretend.call("test.HelpDeskService"),
pretend.call("test.NotificationService"),
]
assert config.register_service_factory.calls == [
pretend.call(helpdesk_class.create_service, IHelpDeskService)
pretend.call(dummy_klass.create_service, IHelpDeskService),
pretend.call(dummy_klass.create_service, IAdminNotificationService),
]
76 changes: 74 additions & 2 deletions tests/unit/helpdesk/test_services.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,13 @@
from pyramid_retry import RetryableException
from zope.interface.verify import verifyClass

from warehouse.helpdesk.interfaces import IHelpDeskService
from warehouse.helpdesk.services import ConsoleHelpDeskService, HelpScoutService
from warehouse.helpdesk.interfaces import IAdminNotificationService, IHelpDeskService
from warehouse.helpdesk.services import (
ConsoleAdminNotificationService,
ConsoleHelpDeskService,
HelpScoutService,
SlackAdminNotificationService,
)


@pytest.mark.parametrize("service_class", [ConsoleHelpDeskService, HelpScoutService])
Expand Down Expand Up @@ -217,3 +222,70 @@ def test_add_tag_with_duplicate(self):

# No PUT call should be made
assert len(responses.calls) == 1


@pytest.mark.parametrize(
"service_class", [ConsoleAdminNotificationService, SlackAdminNotificationService]
)
class TestAdminNotificationService:
"""Common tests for the service interface."""

def test_verify_service_class(self, service_class):
assert verifyClass(IAdminNotificationService, service_class)

def test_create_service(self, service_class):
context = None
request = pretend.stub(
http=pretend.stub(),
log=pretend.stub(
debug=pretend.call_recorder(lambda msg: None),
),
registry=pretend.stub(
settings={
"helpdesk.notification_service_url": "https://webhook.example/1234",
}
),
)

service = service_class.create_service(context, request)
assert isinstance(service, service_class)


class TestConsoleAdminNotificationService:
def test_send_notification(self, capsys):
service = ConsoleAdminNotificationService()

service.send_notification(payload={"text": "Hello, World!"})

captured = capsys.readouterr()

expected = dedent(
"""\
Webhook notification sent
payload:
{'text': 'Hello, World!'}
"""
)
assert captured.out == expected


class TestSlackAdminNotificationService:
@responses.activate
def test_send_notification(self):
responses.add(
responses.POST,
"https://webhook.example/1234",
json={"ok": True},
)

service = SlackAdminNotificationService(
session=requests.Session(),
webhook_url="https://webhook.example/1234",
)

service.send_notification(payload={"text": "Hello, World!"})

assert len(responses.calls) == 1
post_call = responses.calls[0]
assert post_call.request.url == "https://webhook.example/1234"
assert post_call.response.json() == {"ok": True}
169 changes: 165 additions & 4 deletions tests/unit/observations/test_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,14 @@

from warehouse.observations.models import ObservationKind
from warehouse.observations.tasks import (
execute_observation_report,
evaluate_project_for_quarantine,
react_to_observation_created,
report_observation_to_helpscout,
)
from warehouse.packaging.models import LifecycleStatus

from ...common.db.accounts import UserFactory
from ...common.db.packaging import ProjectFactory, RoleFactory
from ...common.db.packaging import ProjectFactory, ReleaseFactory, RoleFactory


def test_execute_observation_report(app_config):
Expand All @@ -29,9 +31,9 @@ def test_execute_observation_report(app_config):
observation = pretend.stub(id=pretend.stub())
session = pretend.stub(info={"warehouse.observations.new": {observation}})

execute_observation_report(app_config, session)
react_to_observation_created(app_config, session)

assert _delay.calls == [pretend.call(observation.id)]
assert _delay.calls == [pretend.call(observation.id), pretend.call(observation.id)]


@pytest.mark.parametrize(
Expand Down Expand Up @@ -75,3 +77,162 @@ def test_report_observation_to_helpscout(

# If it's not supposed to report, then we shouldn't have called the service
assert bool(hs_svc_spy.calls) == reports


class TestAutoQuarantineProject:
def test_non_malware_observation_does_not_quarantine(self, db_request):
dummy_task = pretend.stub(name="dummy_task")
user = UserFactory.create()
db_request.user = user
project = ProjectFactory.create()

observation = project.record_observation(
request=db_request,
kind=ObservationKind.IsDependencyConfusion,
summary="Project Observation",
payload={},
actor=user,
)
# Need to flush the session to ensure the Observation has an ID
db_request.db.flush()

evaluate_project_for_quarantine(dummy_task, db_request, observation.id)

assert project.lifecycle_status != LifecycleStatus.QuarantineEnter
assert db_request.log.info.calls == [
pretend.call("ObservationKind is not IsMalware. Not quarantining.")
]

def test_already_quarantined_project_does_not_do_anything(self, db_request):
dummy_task = pretend.stub(name="dummy_task")
user = UserFactory.create()
db_request.user = user
project = ProjectFactory.create(
lifecycle_status=LifecycleStatus.QuarantineEnter
)

observation = project.record_observation(
request=db_request,
kind=ObservationKind.IsMalware,
summary="Project Observation",
payload={},
actor=user,
)
# Need to flush the session to ensure the Observation has an ID
db_request.db.flush()

evaluate_project_for_quarantine(dummy_task, db_request, observation.id)

assert project.lifecycle_status == LifecycleStatus.QuarantineEnter
assert db_request.log.info.calls == [
pretend.call("Project is already quarantined. No change needed.")
]

def test_not_enough_observers_does_not_quarantine(self, db_request):
dummy_task = pretend.stub(name="dummy_task")
user = UserFactory.create()
db_request.user = user
project = ProjectFactory.create()

observation = project.record_observation(
request=db_request,
kind=ObservationKind.IsMalware,
summary="Project Observation",
payload={},
actor=user,
)
# Need to flush the session to ensure the Observation has an ID
db_request.db.flush()

evaluate_project_for_quarantine(dummy_task, db_request, observation.id)

assert project.lifecycle_status != LifecycleStatus.QuarantineEnter
assert db_request.log.info.calls == [
pretend.call("Project has fewer than 2 observers. Not quarantining.")
]

def test_no_observer_observers_does_not_quarantine(self, db_request):
dummy_task = pretend.stub(name="dummy_task")
user = UserFactory.create()
db_request.user = user
project = ProjectFactory.create()

another_user = UserFactory.create()

# Record 2 observations, but neither are from an observer
project.record_observation(
request=db_request,
kind=ObservationKind.IsMalware,
summary="Project Observation",
payload={},
actor=user,
)
observation = project.record_observation(
request=db_request,
kind=ObservationKind.IsMalware,
summary="Project Observation",
payload={},
actor=another_user,
)
# Need to flush the session to ensure the Observations has an ID
db_request.db.flush()

evaluate_project_for_quarantine(dummy_task, db_request, observation.id)

assert project.lifecycle_status != LifecycleStatus.QuarantineEnter
assert db_request.log.info.calls == [
pretend.call(
"Project has no `User.is_observer` Observers. Not quarantining."
)
]

def test_quarantines_project(self, db_request, notification_service, monkeypatch):
"""
Satisfies criteria for auto-quarantine:
- 2 observations
- from different observers
- one of which is an Observer
"""
dummy_task = pretend.stub(name="dummy_task")
user = UserFactory.create(is_observer=True)
project = ProjectFactory.create()
# Needs a release to be able to quarantine
ReleaseFactory.create(project=project)

another_user = UserFactory.create()

db_request.route_url = pretend.call_recorder(
lambda *args, **kw: "/project/spam/"
)
db_request.user = user

# Record 2 observations, one from an observer
project.record_observation(
request=db_request,
kind=ObservationKind.IsMalware,
summary="Project Observation",
payload={},
actor=user,
)
observation = project.record_observation(
request=db_request,
kind=ObservationKind.IsMalware,
summary="Project Observation",
payload={},
actor=another_user,
)
# Need to flush the session to ensure the Observation has an ID
db_request.db.flush()

ns_svc_spy = pretend.call_recorder(lambda *args, **kwargs: None)
monkeypatch.setattr(notification_service, "send_notification", ns_svc_spy)

evaluate_project_for_quarantine(dummy_task, db_request, observation.id)

assert len(ns_svc_spy.calls) == 1
assert project.lifecycle_status == LifecycleStatus.QuarantineEnter
assert db_request.log.info.calls == [
pretend.call(
"Auto-quarantining project due to multiple malware observations."
),
]
9 changes: 9 additions & 0 deletions warehouse/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -480,6 +480,15 @@ def configure(settings=None):
maybe_set(settings, "helpscout.app_id", "HELPSCOUT_WAREHOUSE_APP_ID")
maybe_set(settings, "helpscout.app_secret", "HELPSCOUT_WAREHOUSE_APP_SECRET")
maybe_set(settings, "helpscout.mailbox_id", "HELPSCOUT_WAREHOUSE_MAILBOX_ID")
# Admin notification service settings
maybe_set(
settings, "helpdesk.notification_backend", "HELPDESK_NOTIFICATION_BACKEND"
)
maybe_set(
settings,
"helpdesk.notification_service_url",
"HELPDESK_NOTIFICATION_SERVICE_URL",
)

# Configure our ratelimiters
maybe_set(
Expand Down
Loading

0 comments on commit f46c35f

Please sign in to comment.