You are given:
A very large dataset distributed across multiple machines (typically 10 workers)
Pre-built interface functions: send(worker_id, data) and recv()
Each machine has a portion of the dataset stored locally
Implement a distributed system to find the mode (most frequent value) of the entire dataset.
Critical Constraint: Data transmission speed is the system's bottleneck. You must avoid having all machines simultaneously send data to a single machine, as this creates severe reception delays.
Below is the interface you must use:
# Send data to a specific worker
def send(worker_id, data):
    pass

# Receive data from any worker
def recv():
    pass
Dataset characteristics:
The dataset contains mostly unique integers
One integer appears twice, and its two copies are stored on different machines
All other values appear only once
Environment: Python
Worker count: 10 machines (configurable via WORKER_NUM)
Current worker ID: Available as worker_id
Your solution should:
Use the provided interface — do not modify or reimplement send() and recv()
Avoid bottlenecks — distribute communication load evenly across all workers
Maximize parallelism — workers should send and receive simultaneously when possible
Handle latency — account for both send and receive operation latency
Return the mode — the value that appears most frequently in the entire dataset
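Bandwidth management: Partitioning keys with key % num_workers (the modulo-based distribution strategy) is crucial. It spreads the shuffle traffic evenly so that no single worker receives everyone's data.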
Clarify with interviewer: What should be returned if multiple values have the same highest count (ties)? Should we return:
Any one of them?
The smallest/largest value?
All values with the maximum count?
Top K vs Top 1: After repartitioning by key (Phase 2), sending only top 1 from each worker is sufficient to recover the global mode because each worker holds complete counts for a disjoint key set. Without this repartition, per-worker top-1 is not sufficient when the duplicated value is split across machines. If the task requires the K most frequent values or deterministic tie-breaking, send top K (or encode a tie-break rule).
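The following is a tiny in-process check of the claim above. It assumes two workers and the smaller-key-wins tie rule that Phase 4 uses. `top1`, `naive_top1_mode`, and `repartitioned_mode` are hypothetical helpers for illustration, and no networking is involved.

```python
from collections import Counter

def top1(counts):
    """Highest count wins; ties go to the smaller key."""
    return min(counts.items(), key=lambda kv: (-kv[1], kv[0]))

def naive_top1_mode(partitions):
    # Each worker reports only its local top 1; there is no repartition.
    merged = Counter()
    for part in partitions:
        if part:
            key, count = top1(Counter(part))
            merged[key] += count
    return top1(merged)[0]

def repartitioned_mode(partitions, num_workers):
    # Phase 2 style: route every (key, count) to worker key % num_workers.
    owned = [Counter() for _ in range(num_workers)]
    for part in partitions:
        for key, count in Counter(part).items():
            owned[key % num_workers][key] += count
    # Each owner has complete counts, so its top 1 is a valid candidate.
    candidates = dict(top1(c) for c in owned if c)
    return top1(candidates)[0]

partitions = [[1, 2], [2, 3]]             # value 2 is split across workers
print(naive_top1_mode(partitions))        # 1 (wrong: each local count is 1)
print(repartitioned_mode(partitions, 2))  # 2 (correct)
```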
Each worker independently processes its local data partition.
from collections import Counter

def phase1_local_count(local_data):
    """
    Each worker counts frequencies in its local partition.

    Args:
        local_data: List of integers stored on this worker
    Returns:
        Counter object: {value: count}
    """
    return Counter(local_data)
Example:
Worker 0 has [1, 2, 3, 5, 10] → {1:1, 2:1, 3:1, 5:1, 10:1}
Worker 1 has [2, 4, 6, 8] → {2:1, 4:1, 6:1, 8:1}
Each worker redistributes its local counts to the appropriate target workers using modulo hashing. This ensures:
All identical keys end up on the same worker
All workers send and receive simultaneously (no bottleneck)
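def phase2_shuffle_counts(local_counter, worker_id, num_workers):
    """
    Redistribute counts to workers based on key % num_workers.
    This ensures that the same key always goes to the same worker,
    and communication is balanced across all workers.

    Args:
        local_counter: Counter from phase 1
        worker_id: Current worker's ID
        num_workers: Total number of workers
    Returns:
        List of (key, count) tuples for the keys this worker owns
    """
    # Prepare buckets for each target worker
    buckets = [[] for _ in range(num_workers)]

    # Distribute keys based on modulo
    for key, count in local_counter.items():
        target_worker = key % num_workers
        buckets[target_worker].append((key, count))

    # Send to all other workers in a staggered order: at step `offset`,
    # worker i sends to (i + offset) % num_workers, so at each step every
    # worker targets a different receiver instead of all hitting worker 0 first.
    # Empty buckets are sent too, so each receiver can expect exactly
    # num_workers - 1 messages.
    for offset in range(1, num_workers):
        target_id = (worker_id + offset) % num_workers
        send(target_id, buckets[target_id])

    # Receive from all other workers.
    # Assumes every message recv() returns here is a Phase 2 bucket. Since
    # recv() accepts messages from any worker, a robust version tags
    # messages by phase (see the tagged-messages note under
    # "Communication protocol considerations").
    received_data = []
    for _ in range(num_workers - 1):
        data = recv()
        received_data.extend(data)

    # Add local bucket (no network transfer needed)
    received_data.extend(buckets[worker_id])
    return received_data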
Why this works:
Worker i receives all keys where key % num_workers == i
If value X appears on workers 3 and 7, both will send their count for X to the same target worker
All workers are sending and receiving simultaneously
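The sketch below applies the same `key % num_workers` rule to the Phase 1 example counters, using 10 workers. Both partial counts for key 2 converge on worker 2. `route_to_buckets` is a hypothetical helper that mirrors only the bucketing loop of `phase2_shuffle_counts`, without any `send()` calls.

```python
from collections import Counter

def route_to_buckets(local_counter, num_workers):
    # Same bucketing rule as Phase 2: target = key % num_workers
    buckets = [[] for _ in range(num_workers)]
    for key, count in local_counter.items():
        buckets[key % num_workers].append((key, count))
    return buckets

worker0 = route_to_buckets(Counter([1, 2, 3, 5, 10]), 10)
worker1 = route_to_buckets(Counter([2, 4, 6, 8]), 10)

# Worker 2 receives both partial counts for key 2 and can sum them.
arrivals_at_2 = worker0[2] + worker1[2]
print(arrivals_at_2)                     # [(2, 1), (2, 1)]
print(sum(c for _, c in arrivals_at_2))  # 2
```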
Each worker aggregates all counts it received for its assigned keys and finds its local maximum.
def phase3_aggregate_local(received_data):
    """
    Aggregate counts for all keys assigned to this worker.

    Args:
        received_data: List of (key, count) tuples from phase 2
    Returns:
        Counter with aggregated counts
        Tuple of (key, count) for the most frequent value in this worker's partition
    """
    # Aggregate all counts for the same key
    aggregated = Counter()
    for key, count in received_data:
        aggregated[key] += count

    # Get the most frequent value in this worker's partition.
    # Since keys are partitioned by modulo, the global mode must be
    # one of the local top-1 values across all workers.
    # Ties go to the smaller key, matching Phase 4. Counter.most_common(1)
    # would instead keep first-encountered order among equal counts, which
    # depends on message arrival order.
    if aggregated:
        top_1 = min(aggregated.items(), key=lambda kv: (-kv[1], kv[0]))
    else:
        top_1 = (None, 0)  # Empty partition
    return aggregated, top_1
Why top 1 is sufficient:
After Phase 2, each worker owns a disjoint set of keys (partitioned by key % num_workers)
Each worker has the complete, final count for all keys assigned to it
The global mode must be the maximum among all workers' local maximums
No need to send top K since keys don't overlap between workers
All workers send their local top 1 to worker 0, which computes the global mode.
def phase4_global_reduction(local_top_1, worker_id, num_workers):
    """
    Gather top 1 from each worker to find global mode.

    Args:
        local_top_1: (key, count) tuple - this worker's most frequent value
        worker_id: Current worker's ID
        num_workers: Total number of workers
    Returns:
        Global mode (only on worker 0, None on others)
    """
    if worker_id == 0:
        # Start with local top 1
        best_key, best_count = local_top_1
        # Receive from all other workers (only 1 value each)
        for _ in range(num_workers - 1):
            remote_key, remote_count = recv()
            # Deterministic tie-breaking: prefer smaller key when counts are equal
            higher = remote_count > best_count
            tie_smaller = (
                remote_count == best_count
                and remote_key is not None
                and best_key is not None
                and remote_key < best_key
            )
            if higher or tie_smaller:
                best_key, best_count = remote_key, remote_count
        return best_key
    else:
        # Send only top 1 to worker 0
        send(0, local_top_1)
        return None
Optimization: Only O(num_workers) values are sent in this phase (one per worker), making it extremely lightweight compared to sending top K.
Note: For larger-scale systems (100+ workers), consider using a tree-based reduction with O(log W) rounds to avoid a single receiver bottleneck. See "Real-World Considerations" section below.
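Because `recv()` returns a message from any worker, worker 0 can receive a fast worker's Phase 4 tuple while it is still collecting Phase 2 buckets from a slower worker. Below is a minimal sketch of one fix. It assumes senders wrap every message as `(phase, payload)`, for example `send(0, ("top1", local_top_1))`, and it adds a receive wrapper that buffers messages belonging to other phases. `TaggedReceiver` and the phase labels are hypothetical. The provided `recv()` is passed in as a callable so that the sketch can be tested in isolation.

```python
from collections import defaultdict, deque

class TaggedReceiver:
    """Wraps an untagged recv() and delivers messages phase by phase."""

    def __init__(self, recv_fn):
        self._recv = recv_fn                # the provided recv()
        self._pending = defaultdict(deque)  # phase -> buffered payloads

    def recv(self, phase):
        # Serve a message buffered earlier for this phase first.
        if self._pending[phase]:
            return self._pending[phase].popleft()
        while True:
            msg_phase, payload = self._recv()
            if msg_phase == phase:
                return payload
            # A message from another phase arrived early; keep it for later.
            self._pending[msg_phase].append(payload)

# Simulated arrival order: a Phase 4 message overtakes two Phase 2 buckets.
inbox = deque([("top1", (7, 2)), ("shuffle", [(3, 1)]), ("shuffle", [(5, 1)])])
rx = TaggedReceiver(inbox.popleft)
print(rx.recv("shuffle"))  # [(3, 1)]
print(rx.recv("shuffle"))  # [(5, 1)]
print(rx.recv("top1"))     # (7, 2)
```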
def find_mode_distributed(local_data, worker_id, num_workers):
    """
    Complete distributed mode finding algorithm.

    Args:
        local_data: List of integers on this worker
        worker_id: Current worker's ID (0 to num_workers-1)
        num_workers: Total number of workers
    Returns:
        The mode of the entire dataset (only on worker 0)
    """
    # Phase 1: Count local frequencies
    local_counter = phase1_local_count(local_data)
    # Phase 2: Shuffle counts to appropriate workers
    received_data = phase2_shuffle_counts(local_counter, worker_id, num_workers)
    # Phase 3: Aggregate and find local top 1
    aggregated, local_top_1 = phase3_aggregate_local(received_data)
    # Phase 4: Global reduction to find mode
    mode = phase4_global_reduction(local_top_1, worker_id, num_workers)
    return mode
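To sanity-check the four phases above without a cluster, the sketch below replays them in one process. List appends stand in for `send()`/`recv()`. `simulate_mode` is a hypothetical test helper, not part of the solution, and it assumes the same smaller-key tie-break as Phase 4.

```python
from collections import Counter

def simulate_mode(partitions):
    num_workers = len(partitions)
    inboxes = [[] for _ in range(num_workers)]

    # Phases 1-2: count locally, then route each (key, count) to its owner.
    for local_data in partitions:
        for key, count in Counter(local_data).items():
            inboxes[key % num_workers].append((key, count))

    # Phase 3: each owner aggregates and picks its top 1.
    tops = []
    for received in inboxes:
        aggregated = Counter()
        for key, count in received:
            aggregated[key] += count
        if aggregated:
            tops.append(min(aggregated.items(), key=lambda kv: (-kv[1], kv[0])))

    # Phase 4: the coordinator keeps the highest count, smaller key on ties.
    if not tops:
        return None
    return min(tops, key=lambda kv: (-kv[1], kv[0]))[0]

# 10 workers with disjoint unique values; 5000 is duplicated on workers 3 and 7.
partitions = [list(range(w * 100, w * 100 + 50)) for w in range(10)]
partitions[3].append(5000)
partitions[7].append(5000)
print(simulate_mode(partitions))  # 5000
```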
Complexity Analysis:
Time: O(N/W + unique_keys/W) per worker, where N = total data size, W = num_workers
Phase 1: O(N/W) to build Counter
Phase 2: O(unique_keys/W) to shuffle
Phase 3: O(unique_keys/W) to aggregate and find max
Phase 4: O(W) on worker 0, O(1) on others
Communication:
Phase 2: O(unique_keys/W) data sent per worker (the main bottleneck)
Phase 4: O(1) data per worker (only one (key, count) tuple)
Space: O(unique_keys/W) per worker
Note on When to Use Top K:
If the problem requires finding the K most frequent values, modify Phase 3 to return most_common(K) and Phase 4 to merge these lists
If ties are common and you need all values with the maximum count, send enough candidates to cover potential ties
For the standard mode problem (single most frequent value), top 1 is optimal
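Below is a sketch of the Top K variant. It assumes each worker has already aggregated its owned keys as in Phase 3. The repartition makes key sets disjoint, so the coordinator only needs the K best entries of the concatenated candidate lists, and no counts have to be summed. Any global top-K entry is also in its owner's local top K under the same ordering. `local_top_k` and `merge_top_k` are hypothetical names. `heapq.nsmallest` keeps a bounded heap internally.

```python
import heapq
from collections import Counter

def local_top_k(aggregated, k):
    # Highest count first; ties broken toward the smaller key.
    return heapq.nsmallest(k, aggregated.items(), key=lambda kv: (-kv[1], kv[0]))

def merge_top_k(per_worker_lists, k):
    # Keys are disjoint across workers after Phase 2, so a plain
    # top-K over all candidates is exact.
    candidates = [kv for lst in per_worker_lists for kv in lst]
    return heapq.nsmallest(k, candidates, key=lambda kv: (-kv[1], kv[0]))

owned = [Counter({10: 3, 20: 1}), Counter({11: 3, 31: 2}), Counter({12: 5})]
lists = [local_top_k(c, 2) for c in owned]
print(merge_top_k(lists, 2))  # [(12, 5), (10, 3)]
```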
After implementing the mode solution, extend your system to find the median of the distributed dataset.
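Unlike mode, the median requires understanding the global ordering of values. Use a distributed selection algorithm that iteratively narrows down the median's location without sorting the entire dataset. The version below is closer to a binary search over the value range than to classic quickselect: the pivot is the midpoint of the current range rather than a sampled data element.
The median is the element at 0-indexed position total_count // 2 when all values are sorted (the upper median when total_count is even). The algorithm: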
Builds a distributed histogram (value → count mapping)
Uses binary search with distributed counting to find the median
Reuse the mode-finding phases 1-2 to create a distributed histogram in which each worker owns the values with value % num_workers == worker_id (a residue class, not a contiguous range of values).
def build_distributed_histogram(local_data, worker_id, num_workers):
    """
    Create distributed histogram using same approach as mode phases 1-2.

    Returns:
        local_histogram: Counter for values assigned to this worker
        total_count: Total number of elements (computed collectively)
    """
    # Phase 1: Local counting
    local_counter = Counter(local_data)

    # Phase 2: Shuffle to appropriate workers
    received_data = phase2_shuffle_counts(local_counter, worker_id, num_workers)

    # Aggregate
    local_histogram = Counter()
    for key, count in received_data:
        local_histogram[key] += count

    # Compute total count (all workers send to worker 0)
    if worker_id == 0:
        total_count = sum(local_histogram.values())
        for _ in range(num_workers - 1):
            remote_count = recv()
            total_count += remote_count
        # Broadcast total to all workers
        for i in range(1, num_workers):
            send(i, total_count)
        return local_histogram, total_count
    else:
        my_count = sum(local_histogram.values())
        send(0, my_count)
        total_count = recv()
        return local_histogram, total_count
Use binary search over the value range to find the value at the median position. Each iteration:
Pick a pivot value
Count how many values are < pivot (distributed across workers)
Adjust search range based on count
def distributed_quickselect(local_histogram, worker_id, num_workers, target_position):
    """
    Find the value at target_position using distributed quickselect.

    Args:
        local_histogram: Counter of values on this worker
        worker_id: Current worker ID
        num_workers: Total workers
        target_position: Position to find (0-indexed)
    Returns:
        The value at target_position in sorted order
    """
    if not local_histogram:
        # This worker has no values, participate in coordination only
        min_val, max_val = float('inf'), float('-inf')
    else:
        min_val, max_val = min(local_histogram.keys()), max(local_histogram.keys())

    # Find global min/max
    if worker_id == 0:
        global_min = min_val
        global_max = max_val
        for _ in range(num_workers - 1):
            remote_min, remote_max = recv()
            global_min = min(global_min, remote_min)
            global_max = max(global_max, remote_max)
        # Broadcast to all
        for i in range(1, num_workers):
            send(i, (global_min, global_max))
    else:
        send(0, (min_val, max_val))
        global_min, global_max = recv()

    # Binary search for median
    low, high = global_min, global_max
    while low < high:
        # Choose pivot (coordinator worker 0 decides)
        if worker_id == 0:
            pivot = (low + high) // 2
            for i in range(1, num_workers):
                send(i, pivot)
        else:
            pivot = recv()

        # Count values < pivot, == pivot, > pivot locally
        count_less = sum(count for val, count in local_histogram.items() if val < pivot)
        count_equal = local_histogram.get(pivot, 0)
        count_greater = sum(count for val, count in local_histogram.items() if val > pivot)

        # Aggregate counts on worker 0
        if worker_id == 0:
            total_less = count_less
            total_equal = count_equal
            total_greater = count_greater
            for _ in range(num_workers - 1):
                remote_less, remote_equal, remote_greater = recv()
                total_less += remote_less
                total_equal += remote_equal
                total_greater += remote_greater

            # Decide next range
            if target_position < total_less:
                # Median is in lower half
                new_low, new_high = low, pivot - 1
            elif target_position < total_less + total_equal:
                # Found the median!
                new_low, new_high = pivot, pivot
            else:
                # Median is in upper half
                new_low, new_high = pivot + 1, high

            # Broadcast decision
            for i in range(1, num_workers):
                send(i, (new_low, new_high))
            low, high = new_low, new_high
        else:
            send(0, (count_less, count_equal, count_greater))
            low, high = recv()

    return low
def find_median_distributed(local_data, worker_id, num_workers):
    """
    Complete distributed median finding algorithm.

    Args:
        local_data: List of integers on this worker
        worker_id: Current worker's ID
        num_workers: Total number of workers
    Returns:
        The median value. Every worker returns the same value because
        worker 0 broadcasts each pivot and range update.
    """
    # Build distributed histogram
    local_histogram, total_count = build_distributed_histogram(
        local_data, worker_id, num_workers
    )

    # Handle empty dataset
    if total_count == 0:
        return None  # or raise, depending on spec

    # Define even-N behavior explicitly:
    # - Upper median (default here): idx = total_count // 2
    # - Lower median: idx = (total_count - 1) // 2
    # - Average of middles: compute both and average (requires two selects)
    # Choose the convention required by the spec/interviewer.
    median_position = total_count // 2  # upper median by default (0-indexed)

    # Use distributed quickselect
    median = distributed_quickselect(
        local_histogram, worker_id, num_workers, median_position
    )
    return median
Complexity Analysis:
Time: O(log(max_value - min_value)) iterations × O(unique_keys/W) per iteration
Communication: O(log(max_value - min_value)) rounds, O(1) data per worker per round (O(W) total per round)
Space: O(unique_keys/W) per worker
Note on Return Value: All workers converge to the same median value through the distributed binary search process. Worker 0 coordinates the search by broadcasting pivot and range updates, ensuring all workers maintain synchronized low and high values. When the loop terminates, low == high == median on all workers.
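The search logic can be checked in one process by replacing the gather and broadcast steps with sums over a list of per-worker histograms. `simulate_select` is a hypothetical helper that follows the same pivot rule and range updates as `distributed_quickselect`. The example computes both the lower and upper median conventions for an even-sized dataset.

```python
from collections import Counter

def simulate_select(histograms, target_position):
    """Value at target_position (0-indexed) across all histograms."""
    keys = [k for h in histograms for k in h]
    low, high = min(keys), max(keys)
    while low < high:
        pivot = (low + high) // 2
        # What worker 0 would obtain by summing every worker's counts.
        total_less = sum(c for h in histograms for v, c in h.items() if v < pivot)
        total_equal = sum(h.get(pivot, 0) for h in histograms)
        if target_position < total_less:
            high = pivot - 1
        elif target_position < total_less + total_equal:
            low = high = pivot
        else:
            low = pivot + 1
    return low

data = [7, 3, 9, 3, 15, -4, 8, 1]
hists = [Counter(), Counter(), Counter()]
for v in data:
    hists[v % 3][v] += 1   # residue-class ownership, as in Phase 2
n = len(data)
lower = simulate_select(hists, (n - 1) // 2)
upper = simulate_select(hists, n // 2)
print(lower, upper)  # sorted is [-4, 1, 3, 3, 7, 8, 9, 15], so: 3 7
```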
If the number of distinct values is small, a simpler gather-based approach works:
def find_median_sorted_ranges(local_histogram, worker_id, num_workers, total_count):
    """
    Alternative: Each worker sends sorted values to worker 0.
    Only efficient if the total number of unique keys is small, because
    worker 0 receives and merges all of them.
    """
    # Each worker sends its sorted (value, count) pairs to worker 0
    sorted_local = sorted(local_histogram.items())

    if worker_id == 0:
        # Collect all sorted ranges
        all_values = []
        all_values.extend(sorted_local)
        for _ in range(num_workers - 1):
            remote_sorted = recv()
            all_values.extend(remote_sorted)

        # Merge and find median
        all_values.sort(key=lambda x: x[0])
        median_position = total_count // 2  # upper median, 0-indexed
        cumulative = 0
        for value, count in all_values:
            cumulative += count
            if cumulative > median_position:
                return value
    else:
        send(0, sorted_local)
        return None
Trade-off: This approach is simpler but creates a bottleneck at worker 0. Use only when unique values are few.
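Every worker already sends its pairs in sorted order. Worker 0 can therefore stream a k-way merge with `heapq.merge` instead of concatenating and re-sorting. The sketch covers only the coordinator-side merge. `median_from_sorted_runs` is a hypothetical name, and it assumes the upper-median convention used above.

```python
import heapq

def median_from_sorted_runs(sorted_runs, total_count):
    # sorted_runs: one sorted list of (value, count) pairs per worker.
    median_position = total_count // 2  # upper median, 0-indexed
    cumulative = 0
    for value, count in heapq.merge(*sorted_runs):
        cumulative += count
        if cumulative > median_position:
            return value
    return None  # empty input

runs = [[(1, 2), (4, 1)], [(2, 1), (5, 3)], [(3, 1)]]
print(median_from_sorted_runs(runs, 8))  # expanded: 1,1,2,3,4,5,5,5 -> 4
```

Because keys are disjoint across workers, no two runs contain the same value. The merge costs O(U log W) for U unique values instead of the O(U log U) of a full sort.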
The following questions are for verbal discussion. Be prepared to explain your approach and reasoning.
Question: What if the data distribution is highly skewed, where one worker gets most of the work after the modulo-based shuffle?
Discussion Points:
Problem: If many distinct keys share the same residue (for example, every key is a multiple of 10 with 10 workers), key % num_workers sends them all to one worker, which becomes a bottleneck
Solutions:
Hash before partitioning: route by stable_hash(key) % num_workers using a well-mixing hash (e.g., MurmurHash, CityHash) instead of key % num_workers. Avoid Python's built-in hash() for strings across processes, since it is randomized per process
Dynamic load balancing: Monitor worker load and reassign buckets
Increase parallelism: Use more workers or sub-partition within workers
Detection: Workers report their load to coordinator; detect imbalance
Trade-offs: Complexity vs performance improvement
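The sketch below shows the skew concretely. With 10 workers, keys that are all multiples of 10 land on worker 0 under `key % num_workers`. Passing each key through a mixing function first spreads them out. The mixer here is the SplitMix64 finalizer. That choice is an assumption for illustration; MurmurHash3's fmix64 would serve the same role. It is deterministic across processes. `mix64` and `bucket_loads` are hypothetical names.

```python
MASK64 = (1 << 64) - 1

def mix64(key):
    """SplitMix64 finalizer: a fast, deterministic 64-bit integer mixer."""
    z = (key + 0x9E3779B97F4A7C15) & MASK64
    z = ((z ^ (z >> 30)) * 0xBF58476D1CE4E5B9) & MASK64
    z = ((z ^ (z >> 27)) * 0x94D049BB133111EB) & MASK64
    return z ^ (z >> 31)

def bucket_loads(keys, num_workers, partition):
    # Number of keys each worker would own under the given partition function.
    loads = [0] * num_workers
    for key in keys:
        loads[partition(key) % num_workers] += 1
    return loads

keys = range(0, 100_000, 10)                # every key is a multiple of 10
print(bucket_loads(keys, 10, lambda k: k))  # [10000, 0, 0, 0, 0, 0, 0, 0, 0, 0]
print(bucket_loads(keys, 10, mix64))        # roughly 1000 per worker
```

Hashing cannot split a single hot key, because all its copies must meet on one worker. In this design Phase 1 already pre-aggregates, so a hot key costs only one (key, count) tuple per worker in the shuffle.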
Question: What if a worker crashes during execution? How would you make the system fault-tolerant?
Discussion Points:
Checkpointing: Save state after each phase
After phase 1: Save local counters
After phase 2: Save received data
Heartbeats: Workers ping coordinator periodically
Replication: Store data on multiple workers
Retry mechanism:
Coordinator detects failed worker
Reassign work to healthy workers
Failed worker's data can be recomputed from original source
Trade-offs:
Performance overhead vs reliability
Storage cost for replication
Recovery time vs checkpoint frequency
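Below is a minimal sketch of per-phase checkpointing. It assumes each worker has a local directory it can write to. The state is written to a temporary file, flushed to disk, and then swapped in with `os.replace`. On POSIX systems (same filesystem) that rename is atomic, so a crash mid-write leaves either the old checkpoint or the new one, never a partial file. The file layout and function names are hypothetical.

```python
import os
import pickle
import tempfile
from collections import Counter

def save_checkpoint(directory, worker_id, phase, state):
    """Persist `state` for (worker_id, phase) via write-then-rename."""
    final_path = os.path.join(directory, f"worker{worker_id}_{phase}.pkl")
    fd, tmp_path = tempfile.mkstemp(dir=directory)
    with os.fdopen(fd, "wb") as f:
        pickle.dump(state, f)
        f.flush()
        os.fsync(f.fileno())
    os.replace(tmp_path, final_path)

def load_checkpoint(directory, worker_id, phase):
    """Return saved state, or None if this phase never completed."""
    path = os.path.join(directory, f"worker{worker_id}_{phase}.pkl")
    try:
        with open(path, "rb") as f:
            return pickle.load(f)
    except FileNotFoundError:
        return None

with tempfile.TemporaryDirectory() as d:
    save_checkpoint(d, 3, "phase1", Counter([1, 2, 2]))
    print(load_checkpoint(d, 3, "phase1"))  # Counter({2: 2, 1: 1})
    print(load_checkpoint(d, 3, "phase2"))  # None
```

On restart, a worker can skip any phase whose checkpoint loads successfully and resume from the next one.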
Question: How would you optimize the algorithm differently for:
Very small datasets (fits in memory of one machine)
Very large datasets (billions of values)
Streaming data (continuously arriving values)
Discussion Points:
Small datasets:
Gather all data to one machine and compute locally
Skip the shuffle phase entirely
Use simpler algorithms (sort + count)
Very large datasets:
Use multi-level aggregation (tree reduction instead of single coordinator)
Compress data before sending (run-length encoding for repeated values)
Use approximate algorithms (Count-Min Sketch, HyperLogLog for cardinality)
Sample-based approaches for median (sample → estimate → verify)
Streaming data:
Maintain sliding windows
Use online algorithms:
For mode: Count-Min Sketch with decay
For median: Two-heap approach (max-heap + min-heap)
Incremental updates instead of full recomputation
Trade accuracy for speed (approximate quantiles)
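The two-heap approach for streaming medians can be sketched with `heapq`. That module only provides a min-heap, so the lower half stores negated values. The sketch is exact for a single unbounded stream. A sliding window would additionally need deletions, which it does not support. `RunningMedian` is a hypothetical name, and `upper_median` follows the upper-median convention of the batch code above.

```python
import heapq

class RunningMedian:
    def __init__(self):
        self.low = []   # max-heap of the smaller half (values stored negated)
        self.high = []  # min-heap of the larger half

    def add(self, value):
        # Route through the lower half so its largest element moves up.
        heapq.heappush(self.low, -value)
        heapq.heappush(self.high, -heapq.heappop(self.low))
        # Keep len(high) equal to len(low) or len(low) + 1.
        if len(self.high) > len(self.low) + 1:
            heapq.heappush(self.low, -heapq.heappop(self.high))

    def upper_median(self):
        # Element at index n // 2 of the sorted stream.
        return self.high[0]

rm = RunningMedian()
for v in [5, 1, 9, 3]:
    rm.add(v)
print(rm.upper_median())  # sorted [1, 3, 5, 9], index 2 -> 5
```

Each insertion costs O(log n), and reading the median is O(1).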
Question: The current implementation has multiple synchronization points where all workers wait. How can we reduce communication overhead and latency?
Discussion Points:
Asynchronous communication:
Use non-blocking send/recv
Pipeline phases: Start phase N+1 while phase N is finishing
Overlap computation and communication
Batching:
Accumulate multiple messages before sending
Reduce number of send/recv calls
Trade-off: Latency vs throughput
Communication patterns:
Tree-based reduction instead of star topology (all → worker 0)
Ring-based algorithms for all-to-all communication
Butterfly network for parallel prefix operations
Compression:
Compress data before transmission
Use delta encoding for sorted values
Protocol buffers or other efficient serialization
Local caching:
Cache frequently accessed values
Reduce redundant queries to remote workers
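Delta encoding can be applied to sorted keys, for example the sorted (value, count) pairs sent in the median gather. The sketch stores differences between consecutive values and zigzag-maps them so negatives stay small. It then writes each as a LEB128-style varint, 7 bits per byte, with the high bit meaning "more bytes follow". The function names are hypothetical, and only the keys are encoded here. Counts could be varint-encoded the same way.

```python
def _zigzag(n):
    # 0, -1, 1, -2, 2 ... -> 0, 1, 2, 3, 4 ...
    return n * 2 if n >= 0 else -n * 2 - 1

def _unzigzag(z):
    return z >> 1 if z % 2 == 0 else -((z + 1) >> 1)

def encode_sorted(values):
    """Delta + zigzag + varint encoding of a sorted list of ints."""
    out = bytearray()
    prev = 0
    for v in values:
        n = _zigzag(v - prev)
        prev = v
        while True:
            byte = n & 0x7F
            n >>= 7
            if n:
                out.append(byte | 0x80)  # more bytes follow
            else:
                out.append(byte)
                break
    return bytes(out)

def decode_sorted(data):
    values, prev, n, shift = [], 0, 0, 0
    for byte in data:
        n |= (byte & 0x7F) << shift
        if byte & 0x80:
            shift += 7
        else:
            prev += _unzigzag(n)
            values.append(prev)
            n, shift = 0, 0
    return values

keys = list(range(1000, 2000))
blob = encode_sorted(keys)
print(len(blob))  # 1001 bytes, versus 8000 as fixed 64-bit integers
```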
Question: How would you extend this framework to compute other statistics like:
Top K most frequent values
Percentiles (P50, P90, P99)
Variance and standard deviation
Distinct count (cardinality)
Discussion Points:
Top K most frequent:
Each worker maintains local top K
Shuffle phase similar to mode
Final reduction: merge K sorted lists from each worker
Use min-heap for efficient top-K tracking
Percentiles:
Similar to median (which is P50)
Distributed quantile sketch (T-Digest, KLL sketch)
Approximate with error bounds
For multiple percentiles: single pass with multiple targets
Variance/Standard deviation:
Compute local mean, variance, count
Use parallel variance formula:
Var(X∪Y) = (n₁·var₁ + n₂·var₂)/(n₁+n₂) + n₁·n₂·(μ₁-μ₂)²/(n₁+n₂)²
Tree reduction to combine statistics
Numerically stable algorithms (Welford's method)
Distinct count:
HyperLogLog algorithm (probabilistic)
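Space: each register needs only O(log log N) bits, so m registers take O(m log log N) bits
Merge HyperLogLog sketches from each worker
Trade-off: a small relative error (standard error ≈ 1.04/√m for m registers, about 1.6% for m = 4096) for massive space savings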
For exact count: Similar to mode (but only track presence, not count)
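The parallel variance formula above can be written as a merge of per-worker summaries `(n, mean, M2)`, where `M2 = n · var` is the sum of squared deviations (population variance). Each worker builds its summary in one pass with Welford's method. The merge is the pairwise update commonly attributed to Chan et al., and summaries can be combined in any tree order. `welford` and `merge` are hypothetical names.

```python
from functools import reduce

def welford(values):
    """Per-worker summary (n, mean, M2) in one numerically stable pass."""
    n, mean, m2 = 0, 0.0, 0.0
    for x in values:
        n += 1
        delta = x - mean
        mean += delta / n
        m2 += delta * (x - mean)
    return n, mean, m2

def merge(a, b):
    """Combine two summaries; variance of the union is M2 / n."""
    na, mean_a, m2a = a
    nb, mean_b, m2b = b
    n = na + nb
    if n == 0:
        return 0, 0.0, 0.0
    delta = mean_b - mean_a
    mean = mean_a + delta * nb / n
    m2 = m2a + m2b + delta * delta * na * nb / n
    return n, mean, m2

partitions = [[2.0, 4.0, 4.0], [4.0, 5.0], [5.0, 7.0, 9.0]]
n, mean, m2 = reduce(merge, map(welford, partitions))
print(n, round(mean, 6), round(m2 / n, 6))  # 8 5.0 4.0
```

Dividing the merged `M2` by `n` reproduces the formula's `(n₁·var₁ + n₂·var₂)/(n₁+n₂) + n₁·n₂·(μ₁-μ₂)²/(n₁+n₂)²` term by term.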
Question: What additional considerations would you need to address in a production system?
Discussion Points:
Monitoring and observability:
Metrics: throughput, latency, data transferred
Tracing: Track data flow across workers
Alerting: Detect stragglers, failures
Dashboards: Visualize system health
Resource management:
Memory limits per worker
CPU throttling to avoid overload
Network bandwidth allocation
Spilling to disk for large intermediate data
Data quality:
Handle missing values (NULL, NaN)
Detect and handle outliers
Data validation before processing
Schema enforcement
Security:
Authentication between workers
Encryption for data in transit
Access control for sensitive data
Audit logging
Testing:
Unit tests for each phase
Integration tests with multiple workers
Chaos testing (random failures)
Performance benchmarks
Correctness tests with known datasets
Communication protocol considerations:
Deadlock prevention: If send() is blocking and buffers are limited, an "all-send then all-recv" pattern can deadlock. Solutions:
Interleave send/recv operations
Use non-blocking I/O
Implement flow control mechanisms
Empty bucket optimization: Skip sending empty buckets in Phase 2 to reduce overhead:
Include metadata (bucket count or termination signal) so receivers know when to stop
Or use tagged messages to identify sender
Hash function selection:
For integer keys: Direct modulo (key % num_workers) is sufficient for uniformly distributed integers
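For non-uniform data or non-integer keys: Use a well-mixing non-cryptographic hash (MurmurHash, CityHash); cryptographic hashes such as MD5 also work but are slower
Cross-process compatibility: Avoid Python's built-in hash() across processes/machines, because hash() of str and bytes is salted per process (see PYTHONHASHSEED), so the same key can map to different workers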
Alternative: Use hashlib for consistent hashing across processes:
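import hashlib

def stable_hash(key):
    # MD5 of the key's string form: identical on every process and machine
    return int(hashlib.md5(str(key).encode()).hexdigest(), 16)

# In Phase 2, replace key % num_workers with:
# target_worker = stable_hash(key) % num_workers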
Scalability patterns:
Tree-based reduction: For 100+ workers, replace star topology with tree reduction in Phase 4:
Reduces worker 0 load from O(W) to O(log W)
Total rounds: O(log W) instead of 1
Trade-off: More complex coordination for better scalability
Hierarchical aggregation: Multi-level grouping for very large clusters (1000+ workers)
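Below is a single-process sketch of a binomial-tree schedule for Phase 4. In each round, with stride 1, 2, 4, and so on, every worker whose ID is an odd multiple of the stride sends its current best to the worker `stride` positions below it. Worker 0 then receives only ceil(log2 W) messages instead of W - 1. `tree_reduce_top1` and `better` are hypothetical helpers. In a real deployment, each update of `best[receiver]` would be one `send()`/`recv()` pair.

```python
def better(a, b):
    """Pick the (key, count) with the higher count; smaller key on ties."""
    if a[0] is None:
        return b
    if b[0] is None:
        return a
    return min(a, b, key=lambda kv: (-kv[1], kv[0]))

def tree_reduce_top1(local_tops):
    best = list(local_tops)  # best[i] = worker i's current candidate
    num_workers = len(best)
    received_by_root = 0
    stride = 1
    while stride < num_workers:
        # Senders this round: IDs that are odd multiples of `stride`.
        for sender in range(stride, num_workers, 2 * stride):
            receiver = sender - stride
            best[receiver] = better(best[receiver], best[sender])
            if receiver == 0:
                received_by_root += 1
        stride *= 2
    return best[0], received_by_root

tops = [(k, 1) for k in range(10)]  # worker i owns keys with key % 10 == i
tops[6] = (16, 2)                   # the duplicated value
tops[3] = (None, 0)                 # an empty partition
print(tree_reduce_top1(tops))       # ((16, 2), 4)
```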
Question: How does this approach compare to using a MapReduce framework like Hadoop or Spark?
Discussion Points:
Similarities:
Both use map (local processing) and reduce (aggregation) phases
Both distribute work across multiple machines
Both use shuffle/partition for data redistribution
Differences:
Specialization: Our solution is specialized for mode/median
Framework overhead: MapReduce has more abstraction layers
Flexibility: MapReduce is more general-purpose
Optimization: We can optimize for our specific use case
When to use each:
Custom solution: When you need maximum performance for a specific task
MapReduce/Spark: When you need flexibility, fault tolerance, and have varied workloads
Hybrid: Use Spark for data engineering, custom code for performance-critical algorithms
Modern alternatives:
Apache Flink for streaming
Dask for Python-native distributed computing
Ray for general-purpose distributed applications
Presto/Trino for distributed SQL queries