Skip to content

Commit

Permalink
Removing a few too many clonings (#51)
Browse files Browse the repository at this point in the history
* Config as ref
* Avoid cloning shard_id strings
  • Loading branch information
gr211 authored Nov 6, 2023
1 parent 05a978f commit 672c7cf
Show file tree
Hide file tree
Showing 7 changed files with 46 additions and 37 deletions.
21 changes: 14 additions & 7 deletions src/kinesis.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use aws_sdk_kinesis::operation::get_shard_iterator::GetShardIteratorOutput;
use chrono::prelude::*;
use chrono::{DateTime, Utc};
use log::{debug, warn};
use std::sync::Arc;
use tokio::sync::mpsc;
use tokio::sync::mpsc::Sender;
use tokio::time::{sleep, Duration};
Expand All @@ -21,7 +22,7 @@ pub mod ticker;

#[async_trait]
pub trait IteratorProvider<K: KinesisClient>: Send + Sync + Clone {
fn get_config(&self) -> ShardProcessorConfig<K>;
fn get_config(&self) -> &ShardProcessorConfig<K>;

async fn get_iterator(&self) -> Result<GetShardIteratorOutput>;
}
Expand All @@ -39,7 +40,13 @@ where
self.seed_shards(tx_shard_iterator_progress.clone()).await?;

while let Some(res) = rx_shard_iterator_progress.recv().await {
let permit = self.get_config().semaphore.acquire_owned().await.unwrap();
let permit = self
.get_config()
.semaphore
.clone()
.acquire_owned()
.await
.unwrap();

let res_clone = res.clone();

Expand Down Expand Up @@ -91,7 +98,7 @@ where
}
}
None => {
if let Some(sender) = self.get_config().tx_ticker_updates {
if let Some(sender) = &self.get_config().tx_ticker_updates {
sender
.send(TickerMessage::RemoveShard(
self.get_config().shard_id.clone(),
Expand All @@ -109,7 +116,7 @@ where

debug!("ShardProcessor {} finished", self.get_config().shard_id);

if let Some(sender) = self.get_config().tx_ticker_updates {
if let Some(sender) = &self.get_config().tx_ticker_updates {
sender
.send(TickerMessage::RemoveShard(
self.get_config().shard_id.clone(),
Expand Down Expand Up @@ -164,7 +171,7 @@ where
tx_shard_iterator_progress: Sender<ShardIteratorProgress>,
) -> Result<()> {
let resp = self.get_config().client.get_records(shard_iterator).await?;
let tx_ticker_updates = self.get_config().tx_ticker_updates;
let tx_ticker_updates = &self.get_config().tx_ticker_updates;

let next_shard_iterator = resp.next_shard_iterator();

Expand All @@ -177,7 +184,7 @@ where
let datetime = *record.approximate_arrival_timestamp().unwrap();

RecordResult {
shard_id: self.get_config().shard_id,
shard_id: Arc::clone(&self.get_config().shard_id),
sequence_id: record.sequence_number().unwrap().into(),
partition_key: record.partition_key().unwrap_or("none").into(),
datetime,
Expand All @@ -194,7 +201,7 @@ where
if let Some(tx_ticker_updates) = tx_ticker_updates {
tx_ticker_updates
.send(TickerMessage::CountUpdate(ShardCountUpdate {
shard_id: self.get_config().shard_id.clone(),
shard_id: Arc::clone(&self.get_config().shard_id),
millis_behind,
nb_records,
}))
Expand Down
8 changes: 4 additions & 4 deletions src/kinesis/helpers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ pub fn new(
config: ShardProcessorConfig {
client,
stream,
shard_id,
shard_id: Arc::new(shard_id),
to_datetime,
semaphore,
tx_records,
Expand All @@ -55,7 +55,7 @@ pub fn new(
config: ShardProcessorConfig {
client,
stream,
shard_id,
shard_id: Arc::new(shard_id),
to_datetime,
semaphore,
tx_records,
Expand All @@ -71,7 +71,7 @@ pub async fn get_latest_iterator<T, K: KinesisClient>(
where
T: IteratorProvider<K>,
{
latest(&iterator_provider.get_config()).iterator().await
latest(iterator_provider.get_config()).iterator().await
}

pub async fn get_iterator_since<T, K: KinesisClient>(
Expand All @@ -81,7 +81,7 @@ pub async fn get_iterator_since<T, K: KinesisClient>(
where
T: IteratorProvider<K>,
{
at_sequence(&iterator_provider.get_config(), starting_sequence_number)
at_sequence(iterator_provider.get_config(), starting_sequence_number)
.iterator()
.await
}
Expand Down
12 changes: 6 additions & 6 deletions src/kinesis/models.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ pub enum ProcessError {

#[derive(Debug, Clone, PartialEq)]
pub struct RecordResult {
pub shard_id: String,
pub shard_id: Arc<String>,
pub sequence_id: String,
pub partition_key: String,
pub datetime: DateTime,
Expand All @@ -48,7 +48,7 @@ pub struct RecordResult {
pub struct ShardProcessorConfig<K: KinesisClient> {
pub client: K,
pub stream: String,
pub shard_id: String,
pub shard_id: Arc<String>,
pub to_datetime: Option<chrono::DateTime<Utc>>,
pub semaphore: Arc<Semaphore>,
pub tx_records: Sender<Result<ShardProcessorADT, ProcessError>>,
Expand All @@ -68,8 +68,8 @@ pub struct ShardProcessorAtTimestamp<K: KinesisClient> {

#[async_trait]
impl<K: KinesisClient> IteratorProvider<K> for ShardProcessorLatest<K> {
fn get_config(&self) -> ShardProcessorConfig<K> {
self.config.clone()
fn get_config(&self) -> &ShardProcessorConfig<K> {
&self.config
}

async fn get_iterator(&self) -> Result<GetShardIteratorOutput> {
Expand All @@ -79,8 +79,8 @@ impl<K: KinesisClient> IteratorProvider<K> for ShardProcessorLatest<K> {

#[async_trait]
impl<K: KinesisClient> IteratorProvider<K> for ShardProcessorAtTimestamp<K> {
fn get_config(&self) -> ShardProcessorConfig<K> {
self.config.clone()
fn get_config(&self) -> &ShardProcessorConfig<K> {
&self.config
}

async fn get_iterator(&self) -> Result<GetShardIteratorOutput> {
Expand Down
24 changes: 12 additions & 12 deletions src/kinesis/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ async fn seed_shards_test() {
config: ShardProcessorConfig {
client,
stream: "test".to_string(),
shard_id: "shardId-000000000000".to_string(),
shard_id: Arc::new("shardId-000000000000".to_string()),
to_datetime: None,
semaphore,
tx_records,
Expand Down Expand Up @@ -80,7 +80,7 @@ async fn seed_shards_test_timestamp_in_future() {
config: ShardProcessorConfig {
client,
stream: "test".to_string(),
shard_id: "shardId-000000000000".to_string(),
shard_id: Arc::new("shardId-000000000000".to_string()),
to_datetime: None,
semaphore,
tx_records,
Expand Down Expand Up @@ -111,7 +111,7 @@ async fn produced_record_is_processed() {
config: ShardProcessorConfig {
client: client.clone(),
stream: "test".to_string(),
shard_id: "shardId-000000000000".to_string(),
shard_id: Arc::new("shardId-000000000000".to_string()),
to_datetime: None,
semaphore,
tx_records,
Expand All @@ -128,7 +128,7 @@ async fn produced_record_is_processed() {
assert_eq!(
ticker_update,
TickerMessage::CountUpdate(ShardCountUpdate {
shard_id: "shardId-000000000000".to_string(),
shard_id: Arc::new("shardId-000000000000".to_string()),
millis_behind: 1000,
nb_records: 1
})
Expand Down Expand Up @@ -160,7 +160,7 @@ async fn beyond_to_timestamp_is_received() {
config: ShardProcessorConfig {
client,
stream: "test".to_string(),
shard_id: "shardId-000000000000".to_string(),
shard_id: Arc::new("shardId-000000000000".to_string()),
to_datetime: Some(to_datetime),
semaphore,
tx_records,
Expand All @@ -175,7 +175,7 @@ async fn beyond_to_timestamp_is_received() {
assert_eq!(
ticker_update,
TickerMessage::CountUpdate(ShardCountUpdate {
shard_id: "shardId-000000000000".to_string(),
shard_id: Arc::new("shardId-000000000000".to_string()),
millis_behind: 1000,
nb_records: 1
})
Expand All @@ -202,7 +202,7 @@ async fn has_records_beyond_end_ts_when_has_end_ts() {
config: ShardProcessorConfig {
client,
stream: "test".to_string(),
shard_id: "shardId-000000000000".to_string(),
shard_id: Arc::new("shardId-000000000000".to_string()),
to_datetime: Some(to_datetime),
semaphore,
tx_records,
Expand All @@ -214,7 +214,7 @@ async fn has_records_beyond_end_ts_when_has_end_ts() {
assert!(processor.has_records_beyond_end_ts(&records));

let record1 = RecordResult {
shard_id: "shard_id".to_string(),
shard_id: Arc::new("shard_id".to_string()),
sequence_id: "sequence_id".to_string(),
partition_key: "partition_key".to_string(),
datetime: DateTime::from_secs(1000),
Expand All @@ -232,7 +232,7 @@ async fn has_records_beyond_end_ts_when_has_end_ts() {
let future_ts = to_datetime.add(chrono::Duration::days(1));

let record2 = RecordResult {
shard_id: "shard_id".to_string(),
shard_id: Arc::new("shard_id".to_string()),
sequence_id: "sequence_id".to_string(),
partition_key: "partition_key".to_string(),
datetime: DateTime::from_millis(future_ts.timestamp_millis()),
Expand Down Expand Up @@ -263,7 +263,7 @@ async fn has_records_beyond_end_ts_when_no_end_ts() {
config: ShardProcessorConfig {
client,
stream: "test".to_string(),
shard_id: "shardId-000000000000".to_string(),
shard_id: Arc::new("shardId-000000000000".to_string()),
to_datetime: None,
semaphore,
tx_records,
Expand All @@ -279,7 +279,7 @@ async fn has_records_beyond_end_ts_when_no_end_ts() {
);

let record = RecordResult {
shard_id: "shard_id".to_string(),
shard_id: Arc::new("shardId-000000000000".to_string()),
sequence_id: "sequence_id".to_string(),
partition_key: "partition_key".to_string(),
datetime: DateTime::from_secs(1000),
Expand Down Expand Up @@ -308,7 +308,7 @@ async fn handle_iterator_refresh_ok() {
config: ShardProcessorConfig {
client,
stream: "test".to_string(),
shard_id: "shardId-000000000000".to_string(),
shard_id: Arc::new("shardId-000000000000".to_string()),
to_datetime: None,
semaphore: Arc::new(Semaphore::new(10)),
tx_records: mpsc::channel::<Result<ShardProcessorADT, ProcessError>>(10).0,
Expand Down
8 changes: 4 additions & 4 deletions src/kinesis/ticker.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,12 @@ use crate::kinesis::ProcessError::Timeout;
#[derive(Debug, Clone, PartialEq)]
pub enum TickerMessage {
CountUpdate(ShardCountUpdate),
RemoveShard(String),
RemoveShard(Arc<String>),
}

#[derive(Debug, Clone, PartialEq)]
pub struct ShardCountUpdate {
pub shard_id: String,
pub shard_id: Arc<String>,
pub millis_behind: i64,
pub nb_records: usize,
}
Expand Down Expand Up @@ -60,7 +60,7 @@ impl Ticker {
let counts = counts.deref_mut();
match res {
TickerMessage::CountUpdate(res) => {
counts.insert(res.shard_id.clone(), res.millis_behind);
counts.insert(res.shard_id.to_string(), res.millis_behind);

if res.nb_records > 0 {
let mut last_ts = self.last_ts.lock().await;
Expand All @@ -69,7 +69,7 @@ impl Ticker {
}
}
TickerMessage::RemoveShard(shard_id) => {
counts.remove(&shard_id);
counts.remove(shard_id.as_str());
}
}
}
Expand Down
7 changes: 4 additions & 3 deletions src/sink/console_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use super::*;
use crate::kinesis::models::ShardProcessorADT::{BeyondToTimestamp, Progress, Termination};
use crate::sink::console::ConsoleSink;
use aws_sdk_kinesis::primitives::DateTime;
use std::sync::Arc;
use tokio::sync::mpsc;

#[test]
Expand Down Expand Up @@ -58,7 +59,7 @@ fn format_outputs_base64() {
let input = b"Hello \xF0\x90\x80World";

let record = RecordResult {
shard_id: "shard_id".to_string(),
shard_id: Arc::new("".to_string()),
sequence_id: "sequence_id".to_string(),
partition_key: "partition_key".to_string(),
datetime: DateTime::from_secs(1_000_000_i64),
Expand All @@ -81,7 +82,7 @@ fn format_outputs_no_base64() {
let input = b"Hello \xF0\x90\x80World";

let record = RecordResult {
shard_id: "shard_id".to_string(),
shard_id: Arc::new("shard_id".to_string()),
sequence_id: "sequence_id".to_string(),
partition_key: "partition_key".to_string(),
datetime: DateTime::from_secs(1_000_000_i64),
Expand Down Expand Up @@ -160,7 +161,7 @@ async fn expect_split() {
tokio::spawn(async move {
tx_records_clone
.send(Ok(Progress(vec![RecordResult {
shard_id: "".to_string(),
shard_id: Arc::new("".to_string()),
sequence_id: "".to_string(),
partition_key: "partition_key".to_string(),
datetime: DateTime::from_secs(1_000_000_i64),
Expand Down
3 changes: 2 additions & 1 deletion src/sink/file_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use crate::sink::Sink;
use aws_sdk_kinesis::primitives::DateTime;
use std::fs;
use std::path::PathBuf;
use std::sync::Arc;
use tokio::sync::mpsc;

#[tokio::test]
Expand All @@ -24,7 +25,7 @@ async fn file_sink_ok() {
tokio::spawn(async move {
tx_records_clone
.send(Ok(Progress(vec![RecordResult {
shard_id: "".to_string(),
shard_id: Arc::new("".to_string()),
sequence_id: "".to_string(),
partition_key: "partition_key".to_string(),
datetime: DateTime::from_secs(1_000_000_i64),
Expand Down

0 comments on commit 672c7cf

Please sign in to comment.