Implement ShardManagerTokenAware class to split shards along node-token boundaries #1255

Open · wants to merge 24 commits into main

Changes from 1 commit

Commits (24)
75f7afb
WIP: Initial ShardManagerNodeAware implementation
michaeljmarshall Aug 30, 2024
1581fb2
First pass clean up after review
michaeljmarshall Sep 3, 2024
ec6adc1
Throw exception if compaction strategy is node aware and has disk pos…
michaeljmarshall Sep 3, 2024
514220e
Add IsolatedTokenAllocator; fix snitch override in TokenMetadata
michaeljmarshall Sep 4, 2024
73a1b15
Fix edge cases after manual testing
michaeljmarshall Sep 4, 2024
ad1a5fa
Add is_node_aware to UCS controller config
michaeljmarshall Sep 6, 2024
e5ae871
Address review feedback
michaeljmarshall Sep 6, 2024
a7242f5
Add test with rf 3 and 3 racks
michaeljmarshall Sep 9, 2024
24afc09
Merge remote-tracking branch 'datastax/main' into node-aware-shard-ma…
michaeljmarshall Sep 9, 2024
9d38e0d
Add temporary debug logging
michaeljmarshall Sep 9, 2024
1c4fb8e
Cleanup original isNodeAware implementation
michaeljmarshall Sep 9, 2024
22ba8af
Make tests compile
michaeljmarshall Sep 10, 2024
43fc9d4
Add some additional debug logs
michaeljmarshall Sep 10, 2024
97063dc
Fix logic for computeUniformSplitPoints
michaeljmarshall Sep 10, 2024
c42c833
Wrap code that seems to fail with try--add more logging
michaeljmarshall Sep 13, 2024
038e539
Use i, not pos, to get entry from array
michaeljmarshall Sep 13, 2024
778d47b
Fix findTokenAlignedSplitPoints; add tests
michaeljmarshall Sep 17, 2024
50b548d
Rename ShardManagerNodeAware to ShardManagerTokenAware; add basic cac…
michaeljmarshall Sep 17, 2024
6f6d711
Cleanup
michaeljmarshall Sep 17, 2024
8c4b186
Enable DEFAULT_IS_NODE_AWARE for testing
michaeljmarshall Sep 18, 2024
073452c
Fix rangeSpanned implementations
michaeljmarshall Sep 18, 2024
43624fc
Make tests pass; not sure if we want these changes though
michaeljmarshall Sep 19, 2024
46de9f6
Revert "Make tests pass; not sure if we want these changes though"
michaeljmarshall Sep 19, 2024
5ac75cd
Relocate ShardManagerTokenAware init to prevent invalid usage
michaeljmarshall Sep 19, 2024
src/java/org/apache/cassandra/db/compaction/ShardManager.java
@@ -28,9 +28,13 @@
 import org.apache.cassandra.dht.IPartitioner;
 import org.apache.cassandra.dht.Range;
 import org.apache.cassandra.dht.Token;
+import org.apache.cassandra.locator.AbstractReplicationStrategy;
 
 public interface ShardManager
 {
+    // Config to enable using node aware shard manager
+    static final boolean ENABLE_NODE_AWARE_SHARD_MANAGER = Boolean.parseBoolean(System.getProperty("cassandra.enable_node_aware_shard_manager", "false"));
+
     /**
      * Single-partition, and generally sstables with very few partitions, can cover very small sections of the token
      * space, resulting in very high densities.
@@ -40,8 +44,11 @@ public interface ShardManager
      */
     static final double MINIMUM_TOKEN_COVERAGE = Math.scalb(1.0, -48);
 
-    static ShardManager create(DiskBoundaries diskBoundaries)
+    static ShardManager create(DiskBoundaries diskBoundaries, AbstractReplicationStrategy rs)
     {
+        // TODO do we need to deal with DiskBoundaries in astra?
+        if (ENABLE_NODE_AWARE_SHARD_MANAGER)
+            return new ShardManagerNodeAware(rs);
         List<Token> diskPositions = diskBoundaries.getPositions();
         SortedLocalRanges localRanges = diskBoundaries.getLocalRanges();
         IPartitioner partitioner = localRanges.getRealm().getPartitioner();
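Note that the new path is opt-in: the token-aware shard manager is only selected when the cassandra.enable_node_aware_shard_manager system property is true (default false); otherwise ShardManager.create falls through to the existing disk-boundary-based implementations. A minimal way to flip it in a test, assuming only what the diff above shows (the flag is read into a static final field, so the property must be set before the ShardManager interface is initialized):

    // must run before the ShardManager interface is loaded, since the flag is a static final field
    System.setProperty("cassandra.enable_node_aware_shard_manager", "true");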
280 changes: 280 additions & 0 deletions src/java/org/apache/cassandra/db/compaction/ShardManagerNodeAware.java
@@ -0,0 +1,280 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.cassandra.db.compaction;

import java.util.Arrays;
import java.util.Set;

import javax.annotation.Nullable;

import org.apache.cassandra.db.PartitionPosition;
import org.apache.cassandra.dht.IPartitioner;
import org.apache.cassandra.dht.Range;
import org.apache.cassandra.dht.Token;
import org.apache.cassandra.dht.tokenallocator.TokenAllocation;
import org.apache.cassandra.io.sstable.format.SSTableReader;
import org.apache.cassandra.io.sstable.format.SSTableWriter;
import org.apache.cassandra.locator.AbstractReplicationStrategy;
import org.apache.cassandra.locator.TokenMetadata;

/**
* A shard manager implementation that accepts token-allocator-generated-tokens and splits along them to ensure that
* current and future states of the cluster will have sstables within shards, not across them, for sufficiently high
* levels of compaction, which allows nodes to trivially own complete sstables for sufficiently high levels of
* compaction.
*
* If there are not yet enough tokens allocated, use the {@link org.apache.cassandra.dht.tokenallocator.TokenAllocator}
* to allocate more tokens to split along. The key to this implementation is utilizing the same algorithm to allocate
* tokens to nodes and to split ranges for higher levels of compaction.
*/
// I haven't figured out yet whether the interesting part of this class is the fact that we use the token allocator
// to find higher level splits or if it is the node awareness. Is it possible to remove the node awareness and keep
// the allocator's logic or do we need both?
// TODO should we extend ShardManagerDiskAware?

public class ShardManagerNodeAware implements ShardManager
{
    public static final Token[] TOKENS = new Token[0];
    private final AbstractReplicationStrategy rs;
    private final TokenMetadata tokenMetadata;

    public ShardManagerNodeAware(AbstractReplicationStrategy rs)
    {
        this.rs = rs;
        this.tokenMetadata = rs.getTokenMetadata();
    }

    @Override
    public double rangeSpanned(Range<Token> tableRange)
    {
        return tableRange.left.size(tableRange.right);
    }

    @Override
    public double localSpaceCoverage()
    {
        // At the moment, this is global, so it covers the whole range. Might not be right though.
        return 1;
    }

    @Override
    public double shardSetCoverage()
    {
        // For now there are no disks defined, so this is the same as localSpaceCoverage
        return 1;
    }

    @Override
    public ShardTracker boundaries(int shardCount)
    {
        var splitPointCount = shardCount - 1;
        // TODO is it safe to get tokens here and then endpoints later without synchronization?
        var sortedTokens = tokenMetadata.sortedTokens();
        if (splitPointCount > sortedTokens.size())
        {
            // Need to allocate tokens within node boundaries.
            var endpoints = tokenMetadata.getAllEndpoints();
            double additionalSplits = splitPointCount - sortedTokens.size();
            var splitPointsPerNode = (int) Math.ceil(additionalSplits / endpoints.size());
Reviewer:

To generate additionalSplits many tokens the way that the allocation strategy would do it when new nodes are added, we need to do the following:

  • First, get the number of tokens the cluster uses per node, tokensPerNode. This should be sortedTokens.size() / endPoints.size() or DatabaseDescriptor.getNumTokens() (these two are expected to be the same, we should bail out if they aren't, or rely on explicitly specified value in the UCS configuration).
  • Set splitPointNodes = Math.ceil(additionalSplits / tokensPerNode).
  • Create splitPointNodes many new fake node ids, which need to be assigned in racks round-robin.
  • Get the token allocation to assign tokensPerNode many tokens for each of these fake nodes (i.e. use TokenAllocation.create(...) and then repeatedly call allocate on that object). See OfflineTokenAllocator.MultiNodeAllocator, which starts from empty and generates tokens for the given number of nodes in each rack; we want a variation of this which starts with TokenMetadata that matches the current, and then continues adding fake nodes until the required number are generated.
  • We can then flatten the generated token list and truncate it to exactly the required number of extra tokens, then concatenate it at the end of the original tokens and sort (i.e. prefer just choosing the first additionalSplits many from the new instead of using the selection scheme below).
  • The generated tokens should be cached.
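
For illustration only, a minimal, self-contained sketch of that flow, with plain longs standing in for Tokens and a FakeTokenAllocator interface standing in for TokenAllocation / OfflineTokenAllocator; all names here are hypothetical:

    import java.util.ArrayList;
    import java.util.Collections;
    import java.util.List;

    public class ExtraTokenSketch
    {
        /** Stand-in for the real token allocation: returns exactly `count` new tokens for one fake node. */
        interface FakeTokenAllocator
        {
            List<Long> allocate(List<Long> currentTokens, String nodeId, String rack, int count);
        }

        static List<Long> tokensForSplitting(List<Long> sortedTokens,
                                             int splitPointCount,
                                             int tokensPerNode,
                                             int rackCount,
                                             FakeTokenAllocator allocator)
        {
            int additionalSplits = splitPointCount - sortedTokens.size();
            if (additionalSplits <= 0)
                return new ArrayList<>(sortedTokens);

            // One fake node per tokensPerNode extra tokens, assigned to racks round-robin.
            int fakeNodeCount = (int) Math.ceil((double) additionalSplits / tokensPerNode);
            List<Long> ring = new ArrayList<>(sortedTokens);   // what the allocator "sees"
            List<Long> generated = new ArrayList<>();          // allocation order = impact order
            for (int n = 0; n < fakeNodeCount; n++)
            {
                String nodeId = "fake-node-" + n;              // never joins the ring
                String rack = "rack-" + (n % rackCount);       // round-robin rack assignment
                List<Long> newTokens = allocator.allocate(ring, nodeId, rack, tokensPerNode);
                generated.addAll(newTokens);
                ring.addAll(newTokens);
                Collections.sort(ring);
            }

            // Truncate to exactly the extra tokens we need, then merge with the real tokens and sort.
            List<Long> merged = new ArrayList<>(sortedTokens);
            merged.addAll(generated.subList(0, additionalSplits));
            Collections.sort(merged);
            return merged;                                     // `generated` is what would be cached
        }
    }

A toy allocator that always splits the currently largest gap is enough to exercise this shape; the real implementation would delegate to the token allocator so that the generated tokens match what actual new nodes would receive, which is the point of the scheme.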

michaeljmarshall (Member Author):

I was thinking about it the wrong way. I was focused on finding the right splits for the existing nodes instead of just adding fake new nodes until we have enough tokens. This seems more straightforward.

michaeljmarshall (Member Author):

We can then flatten the generated token list and truncate it to exactly the required number of extra tokens, then concatenate it at the end of the original tokens and sort (i.e. prefer just choosing the first additionalSplits many from the new instead of using the selection scheme below).

I would have thought the selection scheme would give us better data distribution, as opposed to truncating the list of new tokens. Also, if we truncate the list, does that present issues for ensuring that higher levels of UCS have the same splits as lower levels?

michaeljmarshall (Member Author):

I just pushed up a new commit with a proposed implementation. I plan to write tests for it tomorrow.

Reviewer:

if we truncate the list, does that present issues for ensuring that higher levels of UCS have the same splits as lower levels?

No, it does not, because the same tokens will be generated again when we generate the higher-count levels (in other words, we can cache the smaller generation round (pre-truncation) and use it to save some work when generating the bigger).

I would have thought the selection scheme would give us better data distribution, as opposed to truncating the list of new tokens

The token allocation gives the biggest-impact-token first, which is also the one that splits the current biggest token range so taking them in order should be a good enough choice and we can save the effort. Also, I'm not sure looking for closest to even split can't cause bigger oddities.
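
For illustration, a sketch of that caching idea (longs in place of Tokens; ExtraTokenCache and generateMore are hypothetical names, and generateMore is assumed to continue the deterministic generation where the previous round stopped):

    import java.util.ArrayList;
    import java.util.List;
    import java.util.function.IntFunction;

    class ExtraTokenCache
    {
        private final List<Long> generated = new ArrayList<>();   // pre-truncation generation round
        private final IntFunction<List<Long>> generateMore;       // produces n more tokens, continuing the round

        ExtraTokenCache(IntFunction<List<Long>> generateMore)
        {
            this.generateMore = generateMore;
        }

        /** First `count` extra tokens; the cached round is only ever extended, never regenerated. */
        synchronized List<Long> extraTokens(int count)
        {
            if (count > generated.size())
                generated.addAll(generateMore.apply(count - generated.size()));
            return new ArrayList<>(generated.subList(0, count));
        }
    }

Because the generation is deterministic and ordered by impact, the result for a smaller count is always a prefix of the result for a larger one, which is exactly the reuse described above.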

            // Compute additional tokens since we don't have enough
            for (var endpoint : endpoints)
                sortedTokens.addAll(TokenAllocation.allocateTokens(tokenMetadata, rs, endpoint, splitPointsPerNode));
            // Sort again since we added new tokens.
            sortedTokens.sort(Token::compareTo);
        }
        var splitPoints = findTokenAlignedSplitPoints(sortedTokens.toArray(TOKENS), shardCount);
        return new NodeAlignedShardTracker(shardCount, splitPoints);
    }

    private Token[] findTokenAlignedSplitPoints(Token[] sortedTokens, int shardCount)
    {
        // Short circuit on equal
        if (sortedTokens.length == shardCount - 1)
            return sortedTokens;
        var evenSplitPoints = computeUniformSplitPoints(tokenMetadata.partitioner, shardCount);
        var nodeAlignedSplitPoints = new Token[shardCount - 1];

        // UCS requires that the splitting points for a given density are also splitting points for
        // all higher densities, so we pick from among the existing tokens.
        int pos = 0;
        for (int i = 0; i < evenSplitPoints.length; i++)
        {
            Token value = evenSplitPoints[i];
            pos = Arrays.binarySearch(sortedTokens, pos, evenSplitPoints.length, value);

            if (pos >= 0)
            {
                // Exact match found
                nodeAlignedSplitPoints[i] = sortedTokens[pos];
                pos++;
            }
            else
            {
                // pos is -(insertion point) - 1, so calculate the insertion point
                pos = -pos - 1;

                // Check the neighbors
                Token leftNeighbor = sortedTokens[pos - 1];
                Token rightNeighbor = sortedTokens[pos];

                // Choose the nearest neighbor. By convention, prefer left if value is midpoint.
                if (value.size(leftNeighbor) <= value.size(rightNeighbor))
                {
                    nodeAlignedSplitPoints[i] = leftNeighbor;
                    // No need to bump pos because we decremented it to find the right split token.
                }
                else
                {
                    nodeAlignedSplitPoints[i] = rightNeighbor;
                    pos++;
                }
            }
        }

        return nodeAlignedSplitPoints;
Reviewer:

We can't permit multiple evenSplitPoints to map to the same nodeAlignedSplitPoint, because this effectively reduces the number of split points.

One thing we can do is remove entries from the node-aligned list when we use them, and enforce that the ones picked for smaller shard counts are also picked by recursively generating for shardCount / 2 first until shardCount is not divisible by 2.

michaeljmarshall (Member Author):

I am pretty sure I already did this, though it might not be in the ideal way. As we iterate over sortedTokens, we find the position of the nearest neighbor, and then we set pos to the index of the next token in nodeAlignedSplitPoints. We then use pos as the lower bound in the binary search to find the closest split point. I haven't confirmed that this solution maintains the rule that UCS requires that the splitting points for a given density are also splitting points for all higher densities.

I can see that the recursive solution would trivially ensure UCS requires that the splitting points for a given density are also splitting points for all higher densities. Do you think we should prefer that approach?

michaeljmarshall (Member Author):

I think that mine works assuming we have the power-of-two token allocator.

Reviewer:

Yes, I think this could work too (if we don't select a specific token, it would be because we selected it for a different split point). Added a couple of comments on ensuring no repetition, and not exhausting the sorted tokens too early.

This method needs to be unit-tested.

michaeljmarshall (Member Author):

I hit a bug in this part of the code when working on adding unit tests, and I think that it would be better to have a cleaner and clearer design. In general, I don't yet see how we guarantee that we select the same tokens at successively higher levels as tokens are added to and removed from the sorted tokens list. Is it safe to assume a certain

One thing we can do is remove entries from the node-aligned list when we use them, and enforce that the ones picked for smaller shard counts are also picked by recursively generating for shardCount / 2 first until shardCount is not divisible by 2.

I tried implementing this, but I ran into trouble because I assumed I would have enough split points if I broke the space up by the nearest token to the midpoint:

    private void findTokenAlignedSplitPoints(Token[] sortedTokens, int min, int max, Token left, Token right, int splitPointCount, List<Token> splitPoints)
    {
        splitPointCount--;
        var midpoint = partitioner.midpoint(left, right);
        var index = Arrays.binarySearch(sortedTokens, min, max, midpoint);
        if (index < 0)
        {
            // -(insertion point) - 1
            System.out.println("Index: " + index + " midpoint: " + midpoint);
            index = -index - 1;
            if (index != 0 && index != sortedTokens.length - 1)
            {
                // Check to see which neighbor is closer
                var leftNeighbor = sortedTokens[index - 1];
                var rightNeighbor = sortedTokens[index];
                index = leftNeighbor.size(midpoint) <= midpoint.size(rightNeighbor) ? index - 1 : index;
            }
        }
        var tokenAlignedMidpoint = sortedTokens[index];
        var leftSplitPointCount = splitPointCount / 2;
        var rightSplitPointCount = splitPointCount - leftSplitPointCount;
        if (leftSplitPointCount > 0)
            findTokenAlignedSplitPoints(sortedTokens, min, index, left, tokenAlignedMidpoint, leftSplitPointCount, splitPoints);
        // Add this split point after finding all split points to the left
        splitPoints.add(tokenAlignedMidpoint);
        if (rightSplitPointCount > 0)
            findTokenAlignedSplitPoints(sortedTokens, index + 1, max, tokenAlignedMidpoint, right, rightSplitPointCount, splitPoints);
    }

That code doesn't work because there is no guarantee that the sorted tokens are split equally. Is that design close to what you were thinking? Or do I literally need to remove tokens from sortedTokens as they are used, so that each successive recursive call gets the next power of two worth of tokens until sortedTokens is exhausted?

Reviewer:

Yes, I did have literal removal from the list in mind. It seems, however, that we can make your original solution work by making sure we always report the right number of splits, by doing something like this:

        int pos = 0;
        for (int i = 0; i < evenSplitPoints.length; i++)
        {
            int min = pos;
            int max = sortedTokens.length - evenSplitPoints.length + i;
            Token value = evenSplitPoints[i];
            pos = Arrays.binarySearch(sortedTokens, min, max, value);
            if (pos < 0)
                pos = -pos - 1;

            if (pos == min)
            {
                // No left neighbor, so choose the right neighbor
                nodeAlignedSplitPoints[i] = sortedTokens[pos];
                pos++;
            }
            else if (pos == max)
            {
                // No right neighbor, so choose the left neighbor
                // This also means that for all greater indexes we don't have a choice.
                for (; i < evenSplitPoints.length; ++i)
                    nodeAlignedSplitPoints[i] = sortedTokens[pos++ - 1];
            }
            else
            {
                // Check the neighbors
                Token leftNeighbor = sortedTokens[pos - 1];
                Token rightNeighbor = sortedTokens[pos];

                // Choose the nearest neighbor. By convention, prefer left if value is midpoint, but don't
                // choose the same token twice.
                if (leftNeighbor.size(value) <= value.size(rightNeighbor))
                {
                    nodeAlignedSplitPoints[i] = leftNeighbor;
                    // No need to bump pos because we decremented it to find the right split token.
                }
                else
                {
                    nodeAlignedSplitPoints[i] = rightNeighbor;
                    pos++;
                }
            }
        }

Could you run this through your test to see if it works? If not, could you upload the test so that I can play with it?

michaeljmarshall (Member Author):

That passed my test, thank you! I'll add some more tests proving that as we add nodes, the split points continue to be aligned.

michaeljmarshall (Member Author):

This algorithm fails to choose the same tokens if we add one shard at a time. For example, given sorted tokens:

[-9193912203636246467, -4846620020038852605, -1955575638654777768, 1845313162618248341, 4442481910831405714, 7317158111889931131]

3 shards chooses split points [-1955575638654777768, 1845313162618248341]
4 shards chooses split points [-4846620020038852605, 1845313162618248341, 4442481910831405714]

Observe that the 3 split points chosen for 4 shards do not contain the 2 split points chosen for 3 shards. I think this is acceptable because the shard count is always baseShardCount * 2 ^ n. Is that correct? I'll push up a test shortly that confirms this works with power of 2 growth.

michaeljmarshall (Member Author):

I just pushed a commit with testRangeEndsAreFromTokenListAndContainLowerRangeEnds to show that this works for powers of 2 shard counts.
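
For illustration, a self-contained sketch of the kind of check such a test can make (not the actual testRangeEndsAreFromTokenListAndContainLowerRangeEnds): the selection loop proposed above, transcribed to plain longs with |a - b| as a stand-in for Token.size and evenly spaced interior split points over the Murmur3 range, run against the token list from the earlier comment:

    import java.util.Arrays;
    import java.util.HashSet;
    import java.util.Set;

    public class PowerOfTwoNestingCheck
    {
        // Token list from the comment above (Murmur3 LongToken values).
        static final long[] TOKENS = { -9193912203636246467L, -4846620020038852605L, -1955575638654777768L,
                                       1845313162618248341L, 4442481910831405714L, 7317158111889931131L };

        // Stand-in for Token.size: absolute distance, computed in double to avoid long overflow.
        static double size(long a, long b) { return Math.abs((double) a - (double) b); }

        // Evenly spaced interior split points over the full Murmur3 range, one per shard boundary.
        static long[] evenSplitPoints(int shardCount)
        {
            long[] points = new long[shardCount - 1];
            double min = -Math.pow(2, 63), max = Math.pow(2, 63);
            for (int k = 1; k < shardCount; k++)
                points[k - 1] = (long) (min + (max - min) * k / shardCount);
            return points;
        }

        // The selection loop proposed above, over longs.
        static long[] tokenAlignedSplitPoints(long[] sortedTokens, int shardCount)
        {
            long[] even = evenSplitPoints(shardCount);
            long[] result = new long[even.length];
            int pos = 0;
            for (int i = 0; i < even.length; i++)
            {
                int min = pos;
                int max = sortedTokens.length - even.length + i;
                pos = Arrays.binarySearch(sortedTokens, min, max, even[i]);
                if (pos < 0)
                    pos = -pos - 1;
                if (pos == min)
                {
                    // No left neighbor, so choose the right neighbor
                    result[i] = sortedTokens[pos++];
                }
                else if (pos == max)
                {
                    // No right neighbor; all remaining picks are forced
                    for (; i < even.length; ++i)
                        result[i] = sortedTokens[pos++ - 1];
                }
                else if (size(sortedTokens[pos - 1], even[i]) <= size(even[i], sortedTokens[pos]))
                {
                    result[i] = sortedTokens[pos - 1];
                }
                else
                {
                    result[i] = sortedTokens[pos++];
                }
            }
            return result;
        }

        public static void main(String[] args)
        {
            long[] two = tokenAlignedSplitPoints(TOKENS, 2);   // expected: [1845313162618248341]
            long[] four = tokenAlignedSplitPoints(TOKENS, 4);  // expected: the 4-shard list quoted above
            Set<Long> larger = new HashSet<>();
            for (long t : four)
                larger.add(t);
            for (long t : two)
                System.out.println(t + " contained in 4-shard split points: " + larger.contains(t));
        }
    }

Run as written, the 4-shard result should match the list quoted above and contain the 2-shard pick, in line with the power-of-two argument.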

Reviewer:

Yeah, we are only guaranteed that findNearest(x) is a subset of findNearest(y) when x is a subset of y. The logic is that if we couldn't select the closest value to a point, it is only because we have already picked it for another point. This shouldn't be too hard to prove formally if we need to do that (I did check if ChatGPT 4o is smart enough to do it, without success).

    }

    private Token[] computeUniformSplitPoints(IPartitioner partitioner, int shardCount)
    {
        var tokenCount = shardCount - 1;
        var tokens = new Token[tokenCount];
        for (int i = 0; i < tokenCount; i++)
        {
            var ratio = ((double) i) / shardCount;
            tokens[i] = partitioner.split(partitioner.getMinimumToken(), partitioner.getMaximumToken(), ratio);
        }
        return tokens;
    }

    private class NodeAlignedShardTracker implements ShardTracker
    {
        private final int shardCount;
        private final Token[] sortedTokens;
        private int index = 0;

        NodeAlignedShardTracker(int shardCount, Token[] sortedTokens)
        {
            this.shardCount = shardCount;
            this.sortedTokens = sortedTokens;
        }

        @Override
        public Token shardStart()
        {
            return sortedTokens[index];
        }

        @Nullable
        @Override
        public Token shardEnd()
        {
            return index + 1 < sortedTokens.length ? sortedTokens[index + 1] : null;
        }

        @Override
        public Range<Token> shardSpan()
        {
            return new Range<>(shardStart(), end());
        }

        @Override
        public double shardSpanSize()
        {
            var start = sortedTokens[index];
            // TODO should this be weighted? I think not because we use the token allocator to get the splits and that
            // currently removes our ability to know the weight, but want to check.
            return start.size(end());
        }

        /**
         * Non-nullable implementation of {@link ShardTracker#shardEnd()}
         * @return the shard end, or the partitioner's minimum token when this is the last shard
         */
        private Token end()
        {
            var end = shardEnd();
            return end != null ? end : sortedTokens[index].minValue();
        }

        @Override
        public boolean advanceTo(Token nextToken)
        {
            var currentEnd = shardEnd();
            if (currentEnd == null || nextToken.compareTo(currentEnd) <= 0)
                return false;
            do
            {
                index++;
                currentEnd = shardEnd();
                if (currentEnd == null)
                    break;
            } while (nextToken.compareTo(currentEnd) > 0);
            return true;
        }

        @Override
        public int count()
        {
            return shardCount;
        }

        @Override
        public double fractionInShard(Range<Token> targetSpan)
        {
            Range<Token> shardSpan = shardSpan();
            Range<Token> covered = targetSpan.intersectionNonWrapping(shardSpan);
            if (covered == null)
                return 0;
            if (covered == targetSpan)
                return 1;
            // TODO confirm this is okay without a weight and without reference to the sortedLocalRange list
            double inShardSize = covered.left.size(covered.right);
            double totalSize = targetSpan.left.size(targetSpan.right);
            return inShardSize / totalSize;
        }

        @Override
        public double rangeSpanned(PartitionPosition first, PartitionPosition last)
        {
            // TODO how do we take local range ownership into account here? The ShardManagerNodeAware is doing that for
            // us, but it seems that this node aware version is possibly off base.
            return ShardManagerNodeAware.this.rangeSpanned(first, last);
        }

        @Override
        public int shardIndex()
        {
            return index;
        }

        @Override
        public long shardAdjustedKeyCount(Set<SSTableReader> sstables)
        {
            // Not sure if this needs a custom implementation yet
            return ShardTracker.super.shardAdjustedKeyCount(sstables);
        }

        @Override
        public void applyTokenSpaceCoverage(SSTableWriter writer)
        {
            // Not sure if this needs a custom implementation yet
            ShardTracker.super.applyTokenSpaceCoverage(writer);
        }
    }
}
}