Distributed Mode and Median

Low-Level Design | Coding | Onsite | Phone | Machine Learning Engineer | Software Engineer | Reported May, 2026 | High Frequency

Basic Problem

You are given:

A very large dataset distributed across multiple machines (typically 10 workers)

Pre-built interface functions: send(worker_id, data) and recv()

Each machine has a portion of the dataset stored locally

Implement a distributed system to find the mode (most frequent value) of the entire dataset.

Critical Constraint: Data transmission speed is the system's bottleneck. You must avoid having all machines simultaneously send data to a single machine, as this creates severe reception delays.

Below is the interface you must use:

# Send data to a specific worker
def send(worker_id, data):
    pass

# Receive data from any worker
def recv():
    pass

Problem Setup

Dataset characteristics:

The dataset contains mostly unique integers

Exactly one integer appears twice, and its two occurrences are on different machines

All other values appear only once

Environment: Python

Worker count: 10 machines (configurable via WORKER_NUM)

Current worker ID: Available as worker_id

Requirements

Your solution should:

Use the provided interface — do not modify or reimplement send() and recv()

Avoid bottlenecks — distribute communication load evenly across all workers

Maximize parallelism — workers should send and receive simultaneously when possible

Handle latency — account for both send and receive operation latency

Return the mode — the value that appears most frequently in the entire dataset

Note for Candidate

Bandwidth management: The modulo-based distribution strategy is crucial to avoid bottlenecks

Clarify with interviewer: What should be returned if multiple values have the same highest count (ties)? Should we return:

Any one of them?

The smallest/largest value?

All values with the maximum count?

Top K vs Top 1: After repartitioning by key (Phase 2), sending only top 1 from each worker is sufficient to recover the global mode because each worker holds complete counts for a disjoint key set. Without this repartition, per-worker top-1 is not sufficient when the duplicated value is split across machines. If the task requires the K most frequent values or deterministic tie-breaking, send top K (or encode a tie-break rule).
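To make the Top K vs Top 1 point concrete, here is a minimal single-process sketch of the failure case. The two partitions are made-up illustration data (not from the problem statement): 7 is the global mode, but each worker only sees it once, so a local top-1 report carries no evidence that it is special.

```python
from collections import Counter

# Hypothetical partitions: 7 is the global mode but is split across workers.
partitions = [[1, 7, 3], [4, 7]]

# Each worker reports only its local top-1, with no repartition first.
# Every local count is 1, so the reported key is an arbitrary count-1 value.
candidates = [Counter(p).most_common(1)[0] for p in partitions]

# Merge the reported candidates the way a coordinator would.
merged = Counter()
for key, count in candidates:
    merged[key] += count

# Ground truth, computed over the whole dataset for comparison only.
true_counts = Counter(x for p in partitions for x in p)
```

No candidate has a count above 1 even though true_counts shows 7 occurring twice, which is why the repartition by key has to happen before the top-1 reduction.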

Detailed Implementation Strategy: Mode

Phase 1: Local Processing

Each worker independently processes its local data partition.

from collections import Counter

def phase1_local_count(local_data):
    """
    Each worker counts frequencies in its local partition.

    Args:
        local_data: List of integers stored on this worker

    Returns:
        Counter object: {value: count}
    """
    return Counter(local_data)

Example:

Worker 0 has [1, 2, 3, 5, 10] → {1:1, 2:1, 3:1, 5:1, 10:1}
Worker 1 has [2, 4, 6, 8] → {2:1, 4:1, 6:1, 8:1}

Phase 2: Distributed Shuffle (Key Innovation)

Each worker redistributes its local counts to the appropriate target workers using modulo hashing. This ensures:

All identical keys end up on the same worker

All workers send and receive simultaneously (no bottleneck)

def phase2_shuffle_counts(local_counter, worker_id, num_workers):
    """
    Redistribute counts to workers based on key % num_workers.

    This ensures that the same key always goes to the same worker,
    and communication is balanced across all workers.

    Args:
        local_counter: Counter from phase 1
        worker_id: Current worker's ID
        num_workers: Total number of workers
    """
    # Prepare buckets for each target worker
    buckets = [[] for _ in range(num_workers)]

    # Distribute keys based on modulo
    for key, count in local_counter.items():
        target_worker = key % num_workers
        buckets[target_worker].append((key, count))

    # Send to all other workers
    for target_id in range(num_workers):
        if target_id != worker_id:
            send(target_id, buckets[target_id])

    # Receive from all other workers
    received_data = []
    for _ in range(num_workers - 1):
        data = recv()
        received_data.extend(data)

    # Add local bucket (no network transfer needed)
    received_data.extend(buckets[worker_id])

    return received_data

Why this works:

Worker i receives all keys where key % num_workers == i

If value X appears on workers 3 and 7, both will send their count for X to the same target worker

All workers are sending and receiving simultaneously
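The ownership property can be checked with a small single-process stand-in for the shuffle. This is only a sketch: simulate_shuffle is a hypothetical helper that replaces send()/recv() with in-memory lists, and it reuses the two-worker example data from Phase 1.

```python
from collections import Counter

def simulate_shuffle(partitions):
    """Single-process stand-in for Phase 2: returns each worker's inbox."""
    num_workers = len(partitions)
    inbox = [[] for _ in range(num_workers)]
    for local_data in partitions:
        for key, count in Counter(local_data).items():
            # Same routing rule as phase2_shuffle_counts.
            inbox[key % num_workers].append((key, count))
    return inbox

inbox = simulate_shuffle([[1, 2, 3, 5, 10], [2, 4, 6, 8]])

# Aggregate each inbox the way Phase 3 does.
owned = [Counter() for _ in inbox]
for worker, pairs in enumerate(inbox):
    for key, count in pairs:
        owned[worker][key] += count
```

With two workers, both copies of 2 are routed to worker 0, so owned[0][2] is the complete count of 2, and worker 1 ends up owning only the odd keys.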

Phase 3: Local Aggregation

Each worker aggregates all counts it received for its assigned keys and finds its local maximum.

def phase3_aggregate_local(received_data):
    """
    Aggregate counts for all keys assigned to this worker.

    Args:
        received_data: List of (key, count) tuples from phase 2

    Returns:
        Counter with aggregated counts
        Tuple of (key, count) for the most frequent value in this worker's partition
    """
    # Aggregate all counts for the same key
    aggregated = Counter()
    for key, count in received_data:
        aggregated[key] += count

    # Get the most frequent value in this worker's partition
    # Since keys are partitioned by modulo, the global mode must be
    # one of the local top-1 values across all workers
    if aggregated:
        top_1 = aggregated.most_common(1)[0]
    else:
        top_1 = (None, 0)  # Empty partition

    return aggregated, top_1

Why top 1 is sufficient:

After Phase 2, each worker owns a disjoint set of keys (partitioned by key % num_workers)

Each worker has the complete, final count for all keys assigned to it

The global mode must be the maximum among all workers' local maximums

No need to send top K since keys don't overlap between workers

Phase 4: Global Reduction

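All workers send their local top 1 to worker 0, which computes the global mode. Because recv() accepts messages from any worker, a worker that finishes Phase 2 early can deliver its Phase 4 message while worker 0 is still collecting Phase 2 buckets; in a real implementation, tag each message with its phase or synchronize between phases so the two kinds of message are not confused.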

def phase4_global_reduction(local_top_1, worker_id, num_workers):
    """
    Gather top 1 from each worker to find global mode.

    Args:
        local_top_1: (key, count) tuple - this worker's most frequent value
        worker_id: Current worker's ID
        num_workers: Total number of workers

    Returns:
        Global mode (only on worker 0, None on others)
    """
    if worker_id == 0:
        # Start with local top 1
        best_key, best_count = local_top_1

        # Receive from all other workers (only 1 value each)
        for _ in range(num_workers - 1):
            remote_key, remote_count = recv()
            # Deterministic tie-breaking: prefer smaller key when counts are equal
            better_count = remote_count > best_count
            tie_smaller_key = (
                remote_count == best_count
                and remote_key is not None
                and best_key is not None
                and remote_key < best_key
            )
            if better_count or tie_smaller_key:
                best_key, best_count = remote_key, remote_count

        return best_key
    else:
        # Send only top 1 to worker 0
        send(0, local_top_1)
        return None

Optimization: Only O(num_workers) values are sent in this phase (one per worker), making it extremely lightweight compared to sending top K.

Note: For larger-scale systems (100+ workers), consider using a tree-based reduction with O(log W) rounds to avoid a single receiver bottleneck. See "Real-World Considerations" section below.
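One way to picture the tree-based reduction is as a schedule of (sender, receiver) pairs per round. The sketch below assumes a binomial-tree layout rooted at worker 0; tree_reduction_schedule is a hypothetical helper, and in each round the receiver would keep the better of its own (key, count) and the one it receives.

```python
def tree_reduction_schedule(num_workers):
    """Return a list of rounds; each round is a list of (sender, receiver) pairs."""
    rounds = []
    step = 1
    while step < num_workers:
        # Workers whose lowest set bit equals `step` send to the worker
        # `step` positions below them and then drop out of the reduction.
        pairs = [
            (sender, sender - step)
            for sender in range(step, num_workers, 2 * step)
        ]
        rounds.append(pairs)
        step *= 2
    return rounds

schedule = tree_reduction_schedule(10)
```

For 10 workers this gives 4 rounds, every non-zero worker sends exactly once, and worker 0 receives one message per round instead of 9 messages at once.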

Complete Mode Implementation

def find_mode_distributed(local_data, worker_id, num_workers):
    """
    Complete distributed mode finding algorithm.

    Args:
        local_data: List of integers on this worker
        worker_id: Current worker's ID (0 to num_workers-1)
        num_workers: Total number of workers

    Returns:
        The mode of the entire dataset (only on worker 0)
    """
    # Phase 1: Count local frequencies
    local_counter = phase1_local_count(local_data)

    # Phase 2: Shuffle counts to appropriate workers
    received_data = phase2_shuffle_counts(local_counter, worker_id, num_workers)

    # Phase 3: Aggregate and find local top 1
    aggregated, local_top_1 = phase3_aggregate_local(received_data)

    # Phase 4: Global reduction to find mode
    mode = phase4_global_reduction(local_top_1, worker_id, num_workers)

    return mode

Complexity Analysis:

Time: O(N/W + unique_keys/W) per worker, where N = total data size, W = num_workers

Phase 1: O(N/W) to build Counter

Phase 2: O(unique_keys/W) to shuffle

Phase 3: O(unique_keys/W) to aggregate and find max

Phase 4: O(W) on worker 0, O(1) on others

Communication:

Phase 2: O(unique_keys/W) data sent per worker (the main bottleneck)

Phase 4: O(1) data per worker (only one (key, count) tuple)

Space: O(unique_keys/W) per worker

Note on When to Use Top K:

If the problem requires finding the K most frequent values, modify Phase 3 to return most_common(K) and Phase 4 to merge these lists

If ties are common and you need all values with the maximum count, send enough candidates to cover potential ties

For the standard mode problem (single most frequent value), top 1 is optimal
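If top K is needed, the Phase 4 merge could look roughly like the sketch below. merge_top_k is a hypothetical helper; it assumes keys were already repartitioned (so no key appears in two lists) and breaks ties by preferring the smaller key, which is one possible convention rather than a requirement of the problem.

```python
import heapq

def merge_top_k(per_worker_top_k, k):
    """Merge per-worker [(key, count), ...] lists into the global top k.

    Assumes keys were repartitioned first, so no key appears in two lists.
    Ties are broken by preferring the smaller key.
    """
    all_candidates = (pair for top in per_worker_top_k for pair in top)
    # Highest count first; among equal counts, smallest key first.
    return heapq.nsmallest(k, all_candidates, key=lambda kc: (-kc[1], kc[0]))

top_2 = merge_top_k([[(4, 3), (8, 1)], [(7, 5), (1, 3)], []], 2)
```

Here top_2 is [(7, 5), (1, 3)]: key 7 wins on count, and key 1 beats key 4 on the tie-break.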

Follow-up: Finding the Median

After implementing the mode solution, extend your system to find the median of the distributed dataset.

Key Insight

Unlike the mode, the median depends on the global ordering of values. Use a quickselect-style distributed selection that iteratively narrows the range containing the median without sorting the entire dataset.

Detailed Implementation Strategy: Median

Overview

The median is the value at 0-indexed position total_count // 2 (the upper median when total_count is even) once all values are sorted. The algorithm:

Builds a distributed histogram (value → count mapping)

Uses binary search with distributed counting to find the median

Phase 1 & 2: Build Distributed Histogram

Reuse the mode-finding phases 1-2 to create a distributed histogram where each worker owns the values with key % num_workers == worker_id (a residue class, not a contiguous range of values).

def build_distributed_histogram(local_data, worker_id, num_workers):
    """
    Create distributed histogram using same approach as mode phases 1-2.

    Returns:
        local_histogram: Counter for values assigned to this worker
        total_count: Total number of elements (computed collectively)
    """
    # Phase 1: Local counting
    local_counter = Counter(local_data)

    # Phase 2: Shuffle to appropriate workers
    received_data = phase2_shuffle_counts(local_counter, worker_id, num_workers)

    # Aggregate
    local_histogram = Counter()
    for key, count in received_data:
        local_histogram[key] += count

    # Compute total count (all workers send to worker 0)
    if worker_id == 0:
        total_count = sum(count for count in local_histogram.values())
        for _ in range(num_workers - 1):
            remote_count = recv()
            total_count += remote_count
        # Broadcast total to all workers
        for i in range(1, num_workers):
            send(i, total_count)
        return local_histogram, total_count
    else:
        my_count = sum(count for count in local_histogram.values())
        send(0, my_count)
        total_count = recv()
        return local_histogram, total_count

Phase 3: Distributed Quickselect

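Despite the name, this is a binary search over the value range [global_min, global_max] rather than a classic quickselect: the pivot is the midpoint of the current range, not a sampled data element. Each iteration: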

Pick a pivot value

Count how many values are < pivot (distributed across workers)

Adjust search range based on count

def distributed_quickselect(local_histogram, worker_id, num_workers, target_position):
    """
    Find the value at target_position using distributed quickselect.

    Args:
        local_histogram: Counter of values on this worker
        worker_id: Current worker ID
        num_workers: Total workers
        target_position: Position to find (0-indexed)

    Returns:
        The value at target_position in sorted order
    """
    if not local_histogram:
        # This worker has no values, participate in coordination only
        min_val, max_val = float('inf'), float('-inf')
    else:
        min_val, max_val = min(local_histogram.keys()), max(local_histogram.keys())

    # Find global min/max
    if worker_id == 0:
        global_min = min_val
        global_max = max_val
        for _ in range(num_workers - 1):
            remote_min, remote_max = recv()
            global_min = min(global_min, remote_min)
            global_max = max(global_max, remote_max)
        # Broadcast to all
        for i in range(1, num_workers):
            send(i, (global_min, global_max))
    else:
        send(0, (min_val, max_val))
        global_min, global_max = recv()

    # Binary search for median
    low, high = global_min, global_max

    while low < high:
        # Choose pivot (coordinator worker 0 decides)
        if worker_id == 0:
            pivot = (low + high) // 2
            for i in range(1, num_workers):
                send(i, pivot)
        else:
            pivot = recv()

        # Count values < pivot, == pivot, > pivot locally
        count_less = sum(count for val, count in local_histogram.items() if val < pivot)
        count_equal = local_histogram.get(pivot, 0)
        count_greater = sum(count for val, count in local_histogram.items() if val > pivot)

        # Aggregate counts on worker 0
        if worker_id == 0:
            total_less = count_less
            total_equal = count_equal
            total_greater = count_greater

            for _ in range(num_workers - 1):
                remote_less, remote_equal, remote_greater = recv()
                total_less += remote_less
                total_equal += remote_equal
                total_greater += remote_greater

            # Decide next range
            if target_position < total_less:
                # Median is in lower half
                new_low, new_high = low, pivot - 1
            elif target_position < total_less + total_equal:
                # Found the median!
                new_low, new_high = pivot, pivot
            else:
                # Median is in upper half
                new_low, new_high = pivot + 1, high

            # Broadcast decision
            for i in range(1, num_workers):
                send(i, (new_low, new_high))

            low, high = new_low, new_high
        else:
            send(0, (count_less, count_equal, count_greater))
            low, high = recv()

    return low

Complete Median Implementation

def find_median_distributed(local_data, worker_id, num_workers):
    """
    Complete distributed median finding algorithm.

    Args:
        local_data: List of integers on this worker
        worker_id: Current worker's ID
        num_workers: Total number of workers

    Returns:
        The median value on every worker. Worker 0 broadcasts each range
        update, so all workers exit the loop with the same low == high.
    """
    # Build distributed histogram
    local_histogram, total_count = build_distributed_histogram(
        local_data, worker_id, num_workers
    )

    # Handle empty dataset
    if total_count == 0:
        return None  # or raise, depending on spec

    # Define even-N behavior explicitly:
    # - Upper median (default here): idx = total_count // 2
    # - Lower median: idx = (total_count - 1) // 2
    # - Average of middles: compute both and average (requires two selects)
    # Choose the convention required by the spec/interviewer.
    median_position = total_count // 2  # upper median by default (0-indexed)

    # Use distributed quickselect
    median = distributed_quickselect(
        local_histogram, worker_id, num_workers, median_position
    )

    return median

Complexity Analysis:

Time: O(log(max_value - min_value)) iterations × O(unique_keys/W) per iteration

Communication: O(log(max_value - min_value)) rounds, O(1) data per worker per round (O(W) total per round)

Space: O(unique_keys/W) per worker

Note on Return Value: All workers converge to the same median value through the distributed binary search process. Worker 0 coordinates the search by broadcasting pivot and range updates, ensuring all workers maintain synchronized low and high values. When the loop terminates, low == high == median on all workers.
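The search logic can be sanity-checked without any networking. The sketch below is a single-process stand-in: select_by_value_search is a hypothetical helper that sums over a list of per-worker histograms where the real code would exchange (count_less, count_equal) messages with worker 0. The sample data is made up.

```python
from collections import Counter

def select_by_value_search(histograms, target_position):
    """Return the value at 0-indexed target_position across all histograms."""
    keys = [k for h in histograms for k in h]
    low, high = min(keys), max(keys)
    while low < high:
        pivot = (low + high) // 2
        # Each "worker" contributes two small integers per round.
        total_less = sum(c for h in histograms for v, c in h.items() if v < pivot)
        total_equal = sum(h.get(pivot, 0) for h in histograms)
        if target_position < total_less:
            high = pivot - 1
        elif target_position < total_less + total_equal:
            return pivot
        else:
            low = pivot + 1
    return low

data = [5, 1, 9, 3, 7, 3, 8]
shards = [Counter(), Counter()]
for x in data:
    shards[x % 2][x] += 1  # same residue-class ownership as the shuffle

median = select_by_value_search(shards, len(data) // 2)
```

The sorted data is [1, 3, 3, 5, 7, 8, 9], so the value at position 3 is 5. The number of rounds depends on the width of the value range, not on how many elements there are.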

Alternative Approach: Sorted Ranges

If the number of distinct values is small, you can use a simpler approach:

def find_median_sorted_ranges(local_histogram, worker_id, num_workers, total_count):
    """
    Alternative: Each worker sends sorted values to worker 0.

    Only practical when the number of unique keys is small enough for
    worker 0 to receive and sort them all.
    """
    # Each worker sends its sorted (value, count) pairs to worker 0
    sorted_local = sorted(local_histogram.items())

    if worker_id == 0:
        # Collect all sorted ranges
        all_values = []
        all_values.extend(sorted_local)

        for _ in range(num_workers - 1):
            remote_sorted = recv()
            all_values.extend(remote_sorted)

        # Merge and find median
        all_values.sort(key=lambda x: x[0])

        median_position = total_count // 2
        cumulative = 0
        for value, count in all_values:
            cumulative += count
            if cumulative > median_position:
                return value
    else:
        send(0, sorted_local)
        return None

Trade-off: This approach is simpler but creates a bottleneck at worker 0. Use only when unique values are few.

Additional Discussion Topics

The following questions are for verbal discussion. Be prepared to explain your approach and reasoning.

1. Handling Data Skew

Question: What if the data distribution is highly skewed, where one worker gets most of the work after the modulo-based shuffle?

Discussion Points:

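Problem: If many keys map to the same worker ID, that worker becomes a bottleneck. For example, with 10 workers and keys that are all multiples of 10, key % 10 routes every key to worker 0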

Solutions:

Use hash functions designed for better distribution (e.g., MurmurHash, CityHash)

Stable hashing: route by a stable hash (e.g., MurmurHash) % num_workers instead of key % num_workers (avoid Python's built-in hash() across processes, since it is randomized per process for str and bytes)

Dynamic load balancing: Monitor worker load and reassign buckets

Increase parallelism: Use more workers or sub-partition within workers

Detection: Workers report their load to coordinator; detect imbalance

Trade-offs: Complexity vs performance improvement
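A quick way to see the skew problem, and what a mixing hash buys, is to compare per-worker load for an adversarial key set. This sketch uses made-up keys (all multiples of 10) and an MD5-based stable_hash as one possible stable mixing function; MD5 is chosen only because it is in the standard library, not for security.

```python
import hashlib
from collections import Counter

def stable_hash(key):
    # MD5 used only as a stable mixing function here, not for security.
    return int(hashlib.md5(str(key).encode()).hexdigest(), 16)

num_workers = 10
keys = range(0, 1000, 10)  # hypothetical skewed input: all multiples of 10

# Number of keys each worker would own under the two routing rules.
modulo_load = Counter(key % num_workers for key in keys)
hashed_load = Counter(stable_hash(key) % num_workers for key in keys)
```

Under plain modulo, all 100 keys land on worker 0; under the hashed rule they are spread over several workers, at the cost of computing a hash per key.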

2. Fault Tolerance

Question: What if a worker crashes during execution? How would you make the system fault-tolerant?

Discussion Points:

Checkpointing: Save state after each phase

After phase 1: Save local counters

After phase 2: Save received data

Heartbeats: Workers ping coordinator periodically

Replication: Store data on multiple workers

Retry mechanism:

Coordinator detects failed worker

Reassign work to healthy workers

Failed worker's data can be recomputed from original source

Trade-offs:

Performance overhead vs reliability

Storage cost for replication

Recovery time vs checkpoint frequency

3. Optimizing for Different Data Sizes

Question: How would you optimize the algorithm differently for:

Very small datasets (fits in memory of one machine)

Very large datasets (billions of values)

Streaming data (continuously arriving values)

Discussion Points:

Small datasets:

Gather all data to one machine and compute locally

Skip the shuffle phase entirely

Use simpler algorithms (sort + count)

Very large datasets:

Use multi-level aggregation (tree reduction instead of single coordinator)

Compress data before sending (run-length encoding for repeated values)

Use approximate algorithms (Count-Min Sketch, HyperLogLog for cardinality)

Sample-based approaches for median (sample → estimate → verify)

Streaming data:

Maintain sliding windows

Use online algorithms:

For mode: Count-Min Sketch with decay

For median: Two-heap approach (max-heap + min-heap)

Incremental updates instead of full recomputation

Trade accuracy for speed (approximate quantiles)
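For reference, the two-heap approach mentioned for streaming medians can be sketched as follows. This is a single-stream version only (it does not address distributing the stream), RunningMedian is a hypothetical class name, and it returns the upper median for even counts to match the convention used earlier.

```python
import heapq

class RunningMedian:
    """Two-heap running median (upper median for even counts)."""

    def __init__(self):
        self.lower = []  # max-heap via negated values: the smaller half
        self.upper = []  # min-heap: the larger half

    def add(self, value):
        # Push through lower so the value lands in the correct half.
        heapq.heappush(self.lower, -value)
        heapq.heappush(self.upper, -heapq.heappop(self.lower))
        # Keep len(upper) equal to len(lower) or len(lower) + 1.
        if len(self.upper) > len(self.lower) + 1:
            heapq.heappush(self.lower, -heapq.heappop(self.upper))

    def median(self):
        return self.upper[0] if self.upper else None

rm = RunningMedian()
medians = []
for x in [5, 1, 9, 3, 7, 3, 8]:
    rm.add(x)
    medians.append(rm.median())
```

Each add costs O(log n) and each median query is O(1), which is why this is attractive for an unbounded stream; a sliding window would additionally need lazy deletion from the heaps.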

4. Communication Optimization

Question: The current implementation has multiple synchronization points where all workers wait. How can we reduce communication overhead and latency?

Discussion Points:

Asynchronous communication:

Use non-blocking send/recv

Pipeline phases: Start phase N+1 while phase N is finishing

Overlap computation and communication

Batching:

Accumulate multiple messages before sending

Reduce number of send/recv calls

Trade-off: Latency vs throughput

Communication patterns:

Tree-based reduction instead of star topology (all → worker 0)

Ring-based algorithms for all-to-all communication

Butterfly network for parallel prefix operations

Compression:

Compress data before transmission

Use delta encoding for sorted values

Protocol buffers or other efficient serialization

Local caching:

Cache frequently accessed values

Reduce redundant queries to remote workers

5. Extending to Other Statistics

Question: How would you extend this framework to compute other statistics like:

Top K most frequent values

Percentiles (P50, P90, P99)

Variance and standard deviation

Distinct count (cardinality)

Discussion Points:

Top K most frequent:

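Shuffle by key as in the mode solution, so each worker holds final counts for its own keys

Each worker then computes the top K of its own partition (a local top K taken before the shuffle is not sufficient)

Final reduction: merge the W top-K lists (one per worker) and keep the best K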

Use min-heap for efficient top-K tracking

Percentiles:

Similar to median (which is P50)

Distributed quantile sketch (T-Digest, KLL sketch)

Approximate with error bounds

For multiple percentiles: single pass with multiple targets
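The "single pass with multiple targets" idea can be sketched on a merged histogram. This assumes the histogram is small enough to gather in one place, and it uses the position rule min(n - 1, p * n // 100), chosen here so that P50 matches the total_count // 2 median convention above; other percentile definitions exist. percentiles_from_histogram is a hypothetical helper.

```python
from collections import Counter

def percentiles_from_histogram(histogram, percentiles):
    """Return {p: value} for each requested percentile in one sorted pass."""
    total = sum(histogram.values())
    # Convert each percentile to a 0-indexed target position, ascending.
    targets = sorted((min(total - 1, p * total // 100), p) for p in percentiles)
    result = {}
    cumulative = 0
    pending = iter(targets)
    current = next(pending, None)
    for value in sorted(histogram):
        cumulative += histogram[value]
        # One value may satisfy several targets at once.
        while current is not None and cumulative > current[0]:
            result[current[1]] = value
            current = next(pending, None)
    return result

histogram = Counter(range(1, 101))  # made-up data: values 1..100, once each
result = percentiles_from_histogram(histogram, [50, 90, 99])
```

For this data the result is {50: 51, 90: 91, 99: 100}. In the fully distributed setting, the same effect is obtained by running the value-range search with several target positions per round.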

Variance/Standard deviation:

Compute local mean, variance, count

Use the parallel formula for combining population variances:

Var(X∪Y) = (n₁·var₁ + n₂·var₂)/(n₁+n₂) + n₁·n₂·(μ₁-μ₂)²/(n₁+n₂)²

Tree reduction to combine statistics

Numerically stable algorithms (Welford's method)
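The combination formula above translates directly into a merge function over (count, mean, variance) summaries. The sketch below uses population variance throughout and hypothetical helper names (summarize, combine); each worker would compute its own summary and the summaries would be merged pairwise, for example along a reduction tree.

```python
def summarize(values):
    """Return (count, mean, population variance) for a non-empty list."""
    n = len(values)
    mean = sum(values) / n
    var = sum((x - mean) ** 2 for x in values) / n
    return n, mean, var

def combine(a, b):
    """Combine two (count, mean, variance) summaries using the formula above."""
    n1, mu1, var1 = a
    n2, mu2, var2 = b
    n = n1 + n2
    mean = (n1 * mu1 + n2 * mu2) / n
    var = (n1 * var1 + n2 * var2) / n + n1 * n2 * (mu1 - mu2) ** 2 / n ** 2
    return n, mean, var

part_a = [1.0, 2.0, 3.0]   # made-up data on "worker A"
part_b = [10.0, 20.0]      # made-up data on "worker B"
merged = combine(summarize(part_a), summarize(part_b))
direct = summarize(part_a + part_b)
```

merged and direct agree up to floating-point error, and only three numbers per worker ever cross the network.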

Distinct count:

HyperLogLog algorithm (probabilistic)

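Space: m registers of O(log log N) bits each, typically a few kilobytes regardless of N

Merge HyperLogLog sketches from each worker

Trade-off: a standard error of about 1.04/√m (roughly 0.8% with m = 16,384 registers) in exchange for massive space savings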

For exact count: Similar to mode (but only track presence, not count)

6. Real-World Considerations

Question: What additional considerations would you need to address in a production system?

Discussion Points:

Monitoring and observability:

Metrics: throughput, latency, data transferred

Tracing: Track data flow across workers

Alerting: Detect stragglers, failures

Dashboards: Visualize system health

Resource management:

Memory limits per worker

CPU throttling to avoid overload

Network bandwidth allocation

Spilling to disk for large intermediate data

Data quality:

Handle missing values (NULL, NaN)

Detect and handle outliers

Data validation before processing

Schema enforcement

Security:

Authentication between workers

Encryption for data in transit

Access control for sensitive data

Audit logging

Testing:

Unit tests for each phase

Integration tests with multiple workers

Chaos testing (random failures)

Performance benchmarks

Correctness tests with known datasets

Communication protocol considerations:

Deadlock prevention: If send() is blocking and buffers are limited, an "all-send then all-recv" pattern can deadlock. Solutions:

Interleave send/recv operations

Use non-blocking I/O

Implement flow control mechanisms

Empty bucket optimization: Skip sending empty buckets in Phase 2 to reduce overhead:

Include metadata (bucket count or termination signal) so receivers know when to stop

Or use tagged messages to identify sender
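One common way to interleave send/recv in an all-to-all exchange is a rotating pairwise schedule: in round r, worker i sends to (i + r) % W and expects data from (i - r) % W. The sketch below only computes that schedule; exchange_schedule is a hypothetical helper, and since the provided recv() takes no source argument, the schedule bounds how many messages head toward any one worker per round rather than selecting the sender.

```python
def exchange_schedule(worker_id, num_workers):
    """Return [(send_to, recv_from), ...] for rounds 1..num_workers-1."""
    return [
        ((worker_id + r) % num_workers, (worker_id - r) % num_workers)
        for r in range(1, num_workers)
    ]

# Example: worker 3's plan in a 10-worker cluster.
plan = exchange_schedule(3, 10)
```

In every round each worker is the target of exactly one send, so no receiver is flooded, and alternating one send with one recv per round avoids the "all-send then all-recv" pattern described above.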

Hash function selection:

For integer keys: Direct modulo (key % num_workers) is sufficient for uniformly distributed integers

For non-uniform data or non-integer keys: Use a well-mixing hash function (e.g., the non-cryptographic MurmurHash or CityHash)

Cross-process compatibility: Avoid Python's built-in hash() across processes/machines. For str and bytes keys it is randomized per process for security (unless PYTHONHASHSEED is fixed)

Alternative: Use hashlib for consistent hashing across processes:

import hashlib

def stable_hash(key):
    return int(hashlib.md5(str(key).encode()).hexdigest(), 16)

target_worker = stable_hash(key) % num_workers

Scalability patterns:

Tree-based reduction: For 100+ workers, replace star topology with tree reduction in Phase 4:

Reduces worker 0 load from O(W) to O(log W)

Total rounds: O(log W) instead of 1

Trade-off: More complex coordination for better scalability

Hierarchical aggregation: Multi-level grouping for very large clusters (1000+ workers)

7. Comparison with MapReduce

Question: How does this approach compare to using a MapReduce framework like Hadoop or Spark?

Discussion Points:

Similarities:

Both use map (local processing) and reduce (aggregation) phases

Both distribute work across multiple machines

Both use shuffle/partition for data redistribution

Differences:

Specialization: Our solution is specialized for mode/median

Framework overhead: MapReduce has more abstraction layers

Flexibility: MapReduce is more general-purpose

Optimization: We can optimize for our specific use case

When to use each:

Custom solution: When you need maximum performance for a specific task

MapReduce/Spark: When you need flexibility, fault tolerance, and have varied workloads

Hybrid: Use Spark for data engineering, custom code for performance-critical algorithms

Modern alternatives:

Apache Flink for streaming

Dask for Python-native distributed computing

Ray for general-purpose distributed applications

Presto/Trino for distributed SQL queries

