Skip to content

Commit

Permalink
Greenbids fix geolookup: fetch from official MaxMind URL + mock dbRea…
Browse files Browse the repository at this point in the history
…der UT (#3626)
  • Loading branch information
EvgeniiMunin authored Jan 17, 2025
1 parent b96c137 commit e9417e3
Show file tree
Hide file tree
Showing 10 changed files with 233 additions and 54 deletions.
Original file line number Diff line number Diff line change
@@ -1,52 +1,128 @@
package org.prebid.server.hooks.modules.greenbids.real.time.data.config;

import io.netty.handler.codec.http.HttpResponseStatus;
import io.vertx.core.file.FileSystem;
import io.vertx.core.http.HttpClientOptions;
import io.vertx.core.http.HttpClientRequest;
import io.vertx.core.http.HttpClientResponse;
import com.maxmind.db.Reader;
import io.vertx.core.Future;
import io.vertx.core.Promise;
import io.vertx.core.Vertx;
import com.maxmind.geoip2.DatabaseReader;
import io.vertx.core.file.OpenOptions;
import io.vertx.core.http.HttpMethod;
import io.vertx.core.http.RequestOptions;
import org.apache.commons.compress.archivers.tar.TarArchiveEntry;
import org.apache.commons.compress.archivers.tar.TarArchiveInputStream;
import org.prebid.server.exception.PreBidException;
import org.prebid.server.log.Logger;
import org.prebid.server.log.LoggerFactory;
import org.prebid.server.vertx.Initializable;

import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.net.URL;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.concurrent.atomic.AtomicReference;
import java.util.zip.GZIPInputStream;

public class DatabaseReaderFactory implements Initializable {

private final String geoLiteCountryUrl;
private static final Logger logger = LoggerFactory.getLogger(DatabaseReaderFactory.class);

private final GreenbidsRealTimeDataProperties properties;

private final Vertx vertx;

private final AtomicReference<DatabaseReader> databaseReaderRef = new AtomicReference<>();

public DatabaseReaderFactory(String geoLitCountryUrl, Vertx vertx) {
this.geoLiteCountryUrl = geoLitCountryUrl;
private final FileSystem fileSystem;

public DatabaseReaderFactory(GreenbidsRealTimeDataProperties properties, Vertx vertx) {
this.properties = properties;
this.vertx = vertx;
this.fileSystem = vertx.fileSystem();
}

@Override
public void initialize(Promise<Void> initializePromise) {
downloadAndExtract()
.onSuccess(databaseReaderRef::set)
.<Void>mapEmpty()
.onComplete(initializePromise);
}

private Future<DatabaseReader> downloadAndExtract() {
final String downloadUrl = properties.getGeoLiteCountryPath();
final String tmpPath = properties.getTmpPath();
return downloadFile(downloadUrl, tmpPath)
.compose(ignored -> vertx.executeBlocking(() -> extractMMDB(tmpPath)))
.onComplete(ar -> removeFile(tmpPath));
}

private Future<Void> downloadFile(String downloadUrl, String tmpPath) {
return fileSystem.open(tmpPath, new OpenOptions())
.compose(tmpFile -> sendHttpRequest(downloadUrl)
.onFailure(ignore -> tmpFile.close())
.compose(response -> response.pipeTo(tmpFile)));
}

private Future<HttpClientResponse> sendHttpRequest(String url) {
final RequestOptions options = new RequestOptions()
.setFollowRedirects(true)
.setMethod(HttpMethod.GET)
.setTimeout(properties.getTimeoutMs())
.setAbsoluteURI(url);

final HttpClientOptions httpClientOptions = new HttpClientOptions()
.setConnectTimeout(properties.getTimeoutMs().intValue())
.setMaxRedirects(properties.getMaxRedirects());

vertx.executeBlocking(() -> {
try {
final URL url = new URL(geoLiteCountryUrl);
final Path databasePath = Files.createTempFile("GeoLite2-Country", ".mmdb");
return vertx.createHttpClient(httpClientOptions).request(options)
.compose(HttpClientRequest::send)
.map(this::validateResponse);
}

private HttpClientResponse validateResponse(HttpClientResponse response) {
final int statusCode = response.statusCode();
if (statusCode != HttpResponseStatus.OK.code()) {
throw new PreBidException("Got unexpected response from server with status code %s and message %s"
.formatted(statusCode, response.statusMessage()));
}
return response;
}

try (InputStream inputStream = url.openStream();
FileOutputStream outputStream = new FileOutputStream(databasePath.toFile())) {
inputStream.transferTo(outputStream);
private DatabaseReader extractMMDB(String tarGzPath) {
try (GZIPInputStream gis = new GZIPInputStream(Files.newInputStream(Path.of(tarGzPath)));
TarArchiveInputStream tarInput = new TarArchiveInputStream(gis)) {

TarArchiveEntry currentEntry;
boolean hasDatabaseFile = false;
while ((currentEntry = tarInput.getNextTarEntry()) != null) {
if (currentEntry.getName().contains("GeoLite2-Country.mmdb")) {
hasDatabaseFile = true;
break;
}
}

if (!hasDatabaseFile) {
throw new PreBidException("GeoLite2-Country.mmdb not found in the archive");
}

return new DatabaseReader.Builder(tarInput)
.fileMode(Reader.FileMode.MEMORY).build();
} catch (IOException e) {
throw new PreBidException("Failed to extract MMDB file", e);
}
}

databaseReaderRef.set(new DatabaseReader.Builder(databasePath.toFile()).build());
} catch (IOException e) {
throw new PreBidException("Failed to initialize DatabaseReader from URL", e);
private void removeFile(String filePath) {
fileSystem.exists(filePath).onSuccess(exists -> {
if (exists) {
fileSystem.delete(filePath)
.onFailure(err -> logger.error("Failed to remove file {}", filePath, err));
}
return null;
}).<Void>mapEmpty()
.onComplete(initializePromise);
});
}

public DatabaseReader getDatabaseReader() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import com.google.cloud.storage.Storage;
import com.google.cloud.storage.StorageOptions;
import io.vertx.core.Vertx;
import org.prebid.server.geolocation.CountryCodeMapper;
import org.prebid.server.hooks.modules.greenbids.real.time.data.model.filter.ThrottlingThresholds;
import org.prebid.server.hooks.modules.greenbids.real.time.data.core.ThrottlingThresholdsFactory;
import org.prebid.server.hooks.modules.greenbids.real.time.data.core.GreenbidsInferenceDataService;
Expand Down Expand Up @@ -32,13 +33,15 @@ public class GreenbidsRealTimeDataConfiguration {

@Bean
DatabaseReaderFactory databaseReaderFactory(GreenbidsRealTimeDataProperties properties, Vertx vertx) {
return new DatabaseReaderFactory(properties.getGeoLiteCountryPath(), vertx);
return new DatabaseReaderFactory(properties, vertx);
}

@Bean
GreenbidsInferenceDataService greenbidsInferenceDataService(DatabaseReaderFactory databaseReaderFactory) {
GreenbidsInferenceDataService greenbidsInferenceDataService(DatabaseReaderFactory databaseReaderFactory,
CountryCodeMapper countryCodeMapper) {

return new GreenbidsInferenceDataService(
databaseReaderFactory, ObjectMapperProvider.mapper());
databaseReaderFactory, ObjectMapperProvider.mapper(), countryCodeMapper);
}

@Bean
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,17 @@ public class GreenbidsRealTimeDataProperties {

String geoLiteCountryPath;

String tmpPath;

String gcsBucketName;

Integer cacheExpirationMinutes;

String onnxModelCacheKeyPrefix;

String thresholdsCacheKeyPrefix;

Long timeoutMs;

Integer maxRedirects;
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,15 @@
import com.fasterxml.jackson.databind.node.ObjectNode;
import com.iab.openrtb.request.BidRequest;
import com.iab.openrtb.request.Device;
import com.iab.openrtb.request.Geo;
import com.iab.openrtb.request.Imp;
import com.maxmind.geoip2.DatabaseReader;
import com.maxmind.geoip2.exception.GeoIp2Exception;
import com.maxmind.geoip2.model.CountryResponse;
import com.maxmind.geoip2.record.Country;
import org.apache.commons.lang3.StringUtils;
import org.prebid.server.exception.PreBidException;
import org.prebid.server.geolocation.CountryCodeMapper;
import org.prebid.server.hooks.modules.greenbids.real.time.data.config.DatabaseReaderFactory;
import org.prebid.server.hooks.modules.greenbids.real.time.data.model.data.ThrottlingMessage;
import org.prebid.server.proto.openrtb.ext.request.ExtImpPrebid;
Expand All @@ -25,6 +27,7 @@
import java.util.Collection;
import java.util.Iterator;
import java.util.List;
import java.util.Locale;
import java.util.Objects;
import java.util.Optional;
import java.util.stream.Collectors;
Expand All @@ -35,9 +38,14 @@ public class GreenbidsInferenceDataService {

private final ObjectMapper mapper;

public GreenbidsInferenceDataService(DatabaseReaderFactory dbReaderFactory, ObjectMapper mapper) {
private final CountryCodeMapper countryCodeMapper;

public GreenbidsInferenceDataService(DatabaseReaderFactory dbReaderFactory,
ObjectMapper mapper,
CountryCodeMapper countryCodeMapper) {
this.databaseReaderFactory = Objects.requireNonNull(dbReaderFactory);
this.mapper = Objects.requireNonNull(mapper);
this.countryCodeMapper = Objects.requireNonNull(countryCodeMapper);
}

public List<ThrottlingMessage> extractThrottlingMessagesFromBidRequest(BidRequest bidRequest) {
Expand Down Expand Up @@ -86,23 +94,38 @@ private List<ThrottlingMessage> extractMessagesForImp(
final String ip = Optional.ofNullable(bidRequest.getDevice())
.map(Device::getIp)
.orElse(null);
final String countryFromIp = getCountry(ip);
final String country = Optional.ofNullable(bidRequest.getDevice())
.map(Device::getGeo)
.map(Geo::getCountry)
.map(countryCodeMapper::mapToAlpha2)
.map(GreenbidsInferenceDataService::getCountryNameFromAlpha2)
.filter(c -> !c.isEmpty())
.orElseGet(() -> getCountry(ip));

return createThrottlingMessages(
bidderNode,
impId,
greenbidsUserAgent,
countryFromIp,
country,
hostname,
hourBucket,
minuteQuadrant);
}

private String getCountry(String ip) {
if (ip == null) {
return null;
}
private static String getCountryNameFromAlpha2(String isoCode) {
return StringUtils.isBlank(isoCode)
? StringUtils.EMPTY
: new Locale(StringUtils.EMPTY, isoCode).getDisplayCountry();
}

private String getCountry(String ip) {
final DatabaseReader databaseReader = databaseReaderFactory.getDatabaseReader();
return ip != null && databaseReader != null
? getCountryFromIpUsingDatabase(databaseReader, ip)
: null;
}

private String getCountryFromIpUsingDatabase(DatabaseReader databaseReader, String ip) {
try {
final InetAddress inetAddress = InetAddress.getByName(ip);
final CountryResponse response = databaseReader.country(inetAddress);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import org.mockito.Mock;
import org.mockito.junit.jupiter.MockitoExtension;
import org.prebid.server.exception.PreBidException;
import org.prebid.server.geolocation.CountryCodeMapper;
import org.prebid.server.hooks.modules.greenbids.real.time.data.config.DatabaseReaderFactory;
import org.prebid.server.hooks.modules.greenbids.real.time.data.model.data.ThrottlingMessage;
import org.prebid.server.hooks.modules.greenbids.real.time.data.util.TestBidRequestProvider;
Expand Down Expand Up @@ -50,16 +51,20 @@ public class GreenbidsInferenceDataServiceTest {
@Mock
private Country country;

@Mock
private CountryCodeMapper countryCodeMapper;

private GreenbidsInferenceDataService target;

@BeforeEach
public void setUp() {
when(databaseReaderFactory.getDatabaseReader()).thenReturn(databaseReader);
target = new GreenbidsInferenceDataService(databaseReaderFactory, TestBidRequestProvider.MAPPER);
target = new GreenbidsInferenceDataService(
databaseReaderFactory, TestBidRequestProvider.MAPPER, countryCodeMapper);
}

@Test
public void extractThrottlingMessagesFromBidRequestShouldReturnValidThrottlingMessages()
public void extractThrottlingMessagesFromBidRequestShouldReturnValidThrottlingMessagesWhenGeoIsNull()
throws IOException, GeoIp2Exception {
// given
final Banner banner = givenBanner();
Expand All @@ -79,20 +84,57 @@ public void extractThrottlingMessagesFromBidRequestShouldReturnValidThrottlingMe

when(databaseReader.country(any(InetAddress.class))).thenReturn(countryResponse);
when(countryResponse.getCountry()).thenReturn(country);
when(country.getName()).thenReturn("US");
when(country.getName()).thenReturn("United States");

// when
final List<ThrottlingMessage> throttlingMessages = target.extractThrottlingMessagesFromBidRequest(bidRequest);

// then
assertThat(throttlingMessages).isNotEmpty();
assertThat(throttlingMessages.getFirst().getBidder()).isEqualTo("rubicon");
assertThat(throttlingMessages.get(1).getBidder()).isEqualTo("appnexus");
assertThat(throttlingMessages.getLast().getBidder()).isEqualTo("pubmatic");
assertThat(throttlingMessages)
.extracting(ThrottlingMessage::getBidder)
.containsExactly("rubicon", "appnexus", "pubmatic");

throttlingMessages.forEach(message -> {
assertThat(message.getAdUnitCode()).isEqualTo("adunitcodevalue");
assertThat(message.getCountry()).isEqualTo("US");
assertThat(message.getCountry()).isEqualTo("United States");
assertThat(message.getHostname()).isEqualTo("www.leparisien.fr");
assertThat(message.getDevice()).isEqualTo("PC");
assertThat(message.getHourBucket()).isEqualTo(String.valueOf(expectedHourBucket));
assertThat(message.getMinuteQuadrant()).isEqualTo(String.valueOf(expectedMinuteQuadrant));
});
}

@Test
public void extractThrottlingMessagesFromBidRequestShouldReturnValidThrottlingMessagesWhenGeoDefined() {
// given
final Banner banner = givenBanner();
final Imp imp = Imp.builder()
.id("adunitcodevalue")
.ext(givenImpExt())
.banner(banner)
.build();
final Device device = givenDevice(identity(), "FRA");
final BidRequest bidRequest = givenBidRequest(request -> request, List.of(imp), device, null);

final ZonedDateTime timestamp = ZonedDateTime.now(ZoneId.of("UTC"));
final Integer expectedHourBucket = timestamp.getHour();
final Integer expectedMinuteQuadrant = (timestamp.getMinute() / 15) + 1;

when(countryCodeMapper.mapToAlpha2("FRA")).thenReturn("FR");

// when
final List<ThrottlingMessage> throttlingMessages = target.extractThrottlingMessagesFromBidRequest(bidRequest);

// then
assertThat(throttlingMessages).isNotEmpty();
assertThat(throttlingMessages)
.extracting(ThrottlingMessage::getBidder)
.containsExactly("rubicon", "appnexus", "pubmatic");

throttlingMessages.forEach(message -> {
assertThat(message.getAdUnitCode()).isEqualTo("adunitcodevalue");
assertThat(message.getCountry()).isEqualTo("France");
assertThat(message.getHostname()).isEqualTo("www.leparisien.fr");
assertThat(message.getDevice()).isEqualTo("PC");
assertThat(message.getHourBucket()).isEqualTo(String.valueOf(expectedHourBucket));
Expand Down Expand Up @@ -121,10 +163,9 @@ public void extractThrottlingMessagesFromBidRequestShouldHandleMissingIp() {

// then
assertThat(throttlingMessages).isNotEmpty();

assertThat(throttlingMessages.getFirst().getBidder()).isEqualTo("rubicon");
assertThat(throttlingMessages.get(1).getBidder()).isEqualTo("appnexus");
assertThat(throttlingMessages.getLast().getBidder()).isEqualTo("pubmatic");
assertThat(throttlingMessages)
.extracting(ThrottlingMessage::getBidder)
.containsExactly("rubicon", "appnexus", "pubmatic");

throttlingMessages.forEach(message -> {
assertThat(message.getAdUnitCode()).isEqualTo("adunitcodevalue");
Expand Down
Loading

0 comments on commit e9417e3

Please sign in to comment.