
Commit 778d47b

Fix findTokenAlignedSplitPoints; add tests
michaeljmarshall committed Sep 17, 2024
1 parent 038e539 commit 778d47b
Showing 4 changed files with 192 additions and 35 deletions.
@@ -129,51 +129,43 @@ private Token[] findTokenAlignedSplitPoints(Token[] sortedTokens, int splitPoint
         int pos = 0;
         for (int i = 0; i < evenSplitPoints.length; i++)
         {
+            int min = pos;
+            int max = sortedTokens.length - evenSplitPoints.length + i;
             Token value = evenSplitPoints[i];
-            pos = Arrays.binarySearch(sortedTokens, pos, sortedTokens.length, value);
+            pos = Arrays.binarySearch(sortedTokens, min, max, value);
+            if (pos < 0)
+                pos = -pos - 1;
 
-            if (pos >= 0)
+            if (pos == min)
             {
-                // Exact match found
+                // No left neighbor, so choose the right neighbor
                 nodeAlignedSplitPoints[i] = sortedTokens[pos];
                 pos++;
             }
+            else if (pos == max)
+            {
+                // No right neighbor, so choose the left neighbor
+                // This also means that for all greater indexes we don't have a choice.
+                for (; i < evenSplitPoints.length; ++i)
+                    nodeAlignedSplitPoints[i] = sortedTokens[pos++ - 1];
+            }
             else
             {
-                // pos is -(insertion point) - 1, so calculate the insertion point
-                pos = -pos - 1;
+                // Check the neighbors
+                Token leftNeighbor = sortedTokens[pos - 1];
+                Token rightNeighbor = sortedTokens[pos];
 
-                if (pos == 0)
+                // Choose the nearest neighbor. By convention, prefer left if value is midpoint, but don't
+                // choose the same token twice.
+                if (leftNeighbor.size(value) <= value.size(rightNeighbor))
                 {
-                    // No left neighbor, so choose the right neighbor
-                    nodeAlignedSplitPoints[i] = sortedTokens[pos];
-                    pos++;
-                }
-                else if (pos == sortedTokens.length)
-                {
-                    // todo assert we're at the end?
-                    // No right neighbor, so choose the left neighbor
-                    nodeAlignedSplitPoints[i] = sortedTokens[pos - 1];
-                    pos++;
+                    nodeAlignedSplitPoints[i] = leftNeighbor;
+                    // No need to bump pos because we decremented it to find the right split token.
                 }
                 else
                 {
-                    // Check the neighbors
-                    Token leftNeighbor = sortedTokens[pos - 1];
-                    Token rightNeighbor = sortedTokens[pos];
-
-                    // Choose the nearest neighbor. By convention, prefer left if value is midpoint, but don't
-                    // choose the same token twice.
-                    if (value.size(leftNeighbor) <= value.size(rightNeighbor) && !leftNeighbor.equals(nodeAlignedSplitPoints[i - 1]))
-                    {
-                        nodeAlignedSplitPoints[i] = leftNeighbor;
-                        // No need to bump pos because we decremented it to find the right split token.
-                    }
-                    else
-                    {
-                        nodeAlignedSplitPoints[i] = rightNeighbor;
-                        pos++;
-                    }
+                    nodeAlignedSplitPoints[i] = rightNeighbor;
+                    pos++;
                 }
             }
         }
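
For context, the rewritten loop binary-searches each evenly spaced split point within a [min, max) window that reserves one remaining token per later split point, then snaps to the nearer existing token without picking the same token twice. Below is a minimal, self-contained sketch of that strategy; it is illustrative only (hypothetical class and method names, plain longs and a simple distance standing in for Cassandra's Token and Token.size()), not the patched method itself.

import java.util.Arrays;

// Illustrative sketch (hypothetical names) mirroring the bounded binary search and
// nearest-neighbor choice in the hunk above. Uses plain longs instead of Token;
// assumes sorted.length >= evenSplits.length.
final class SplitPointAlignmentSketch
{
    static long[] alignToSorted(long[] sorted, long[] evenSplits)
    {
        long[] aligned = new long[evenSplits.length];
        int pos = 0;
        for (int i = 0; i < evenSplits.length; i++)
        {
            int min = pos;
            // Reserve enough tokens on the right for the split points that still need a match.
            int max = sorted.length - evenSplits.length + i;
            pos = Arrays.binarySearch(sorted, min, max, evenSplits[i]);
            if (pos < 0)
                pos = -pos - 1; // insertion point

            if (pos == min)
            {
                // No left neighbor: take the right one.
                aligned[i] = sorted[pos];
                pos++;
            }
            else if (pos == max)
            {
                // No right neighbor: the remaining split points have no choice left.
                for (; i < evenSplits.length; ++i)
                    aligned[i] = sorted[pos++ - 1];
            }
            else
            {
                long left = sorted[pos - 1];
                long right = sorted[pos];
                // Snap to the closer neighbor; ties go left (overflow ignored for brevity).
                // The next search starts at pos, so a chosen left neighbor is never picked twice.
                if (evenSplits[i] - left <= right - evenSplits[i])
                    aligned[i] = left;
                else
                {
                    aligned[i] = right;
                    pos++;
                }
            }
        }
        return aligned;
    }
}

Under this sketch, alignToSorted(new long[]{ 0, 10, 35, 40, 60, 90 }, new long[]{ 30, 60 }) returns { 35, 60 }: the first point snaps to its nearer neighbor 35, and the second matches 60 exactly.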
@@ -0,0 +1,168 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.cassandra.db.compaction;

import java.net.UnknownHostException;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;

import org.junit.Test;

import org.apache.cassandra.config.Config;
import org.apache.cassandra.config.DatabaseDescriptor;
import org.apache.cassandra.dht.Murmur3Partitioner;
import org.apache.cassandra.dht.Token;
import org.apache.cassandra.dht.tokenallocator.TokenAllocation;
import org.apache.cassandra.locator.AbstractReplicationStrategy;
import org.apache.cassandra.locator.InetAddressAndPort;
import org.apache.cassandra.locator.NetworkTopologyStrategy;
import org.apache.cassandra.locator.RackInferringSnitch;
import org.apache.cassandra.locator.TokenMetadata;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertTrue;

public class ShardManagerNodeAwareTest
{

@Test
public void testRangeEndsForShardCountEqualToNumTokensPlusOne() throws UnknownHostException
{
for (int numTokens = 1; numTokens < 32; numTokens++)
{
var rs = buildStrategy(numTokens, 1, 1, 1);
var expectedTokens = rs.getTokenMetadata().sortedTokens();
var shardManager = new ShardManagerNodeAware(rs);

var shardCount = numTokens + 1;
var iterator = shardManager.boundaries(shardCount);
assertEquals(Murmur3Partitioner.instance.getMinimumToken(), iterator.shardStart());
var actualTokens = new ArrayList<Token>();
for (Token end = iterator.shardEnd(); end != null; end = iterator.shardEnd())
{
assertFalse(iterator.advanceTo(end));
assertTrue(iterator.advanceTo(end.nextValidToken()));
actualTokens.add(end);
}
assertEquals(expectedTokens, actualTokens);
}
}

@Test
public void testRangeEndsAreFromTokenListAndContainLowerRangeEnds() throws UnknownHostException
{
for (int nodeCount = 1; nodeCount <= 6; nodeCount++)
{
for (int numTokensPerNode = 1; numTokensPerNode < 16; numTokensPerNode++)
{
// Confirm it works for multiple base shard counts.
for (int baseShardCount = 1; baseShardCount <= 3; baseShardCount++)
{
// Testing with 1 rack, nodeCount nodes, and rf 1.
var rs = buildStrategy(numTokensPerNode, 1, nodeCount, 1);
var initialSplitPoints = rs.getTokenMetadata().sortedTokens();
// Confirm test set up is correct.
assertEquals(numTokensPerNode * nodeCount, initialSplitPoints.size());
var shardManager = new ShardManagerNodeAware(rs);

// The tokens for one level lower.
var lowerTokens = new ArrayList<Token>();
var tokenLimit = numTokensPerNode * nodeCount * 8;
for (int shardExponent = 0; baseShardCount * Math.pow(2, shardExponent) <= tokenLimit; shardExponent++)
{
var shardCount = baseShardCount * (int) Math.pow(2, shardExponent);
var iterator = shardManager.boundaries(shardCount);
assertEquals(Murmur3Partitioner.instance.getMinimumToken(), iterator.shardStart());
assertEquals(shardCount, iterator.count());
var actualSplitPoints = new ArrayList<Token>();
var shardSpanSize = 0d;
var index = -1;
for (Token end = iterator.shardEnd(); end != null; end = iterator.shardEnd())
{
shardSpanSize += iterator.shardSpanSize();
assertEquals(index++, iterator.shardIndex());
assertFalse(iterator.advanceTo(end));
assertTrue(iterator.advanceTo(end.nextValidToken()));
actualSplitPoints.add(end);
}
// Need to add the last shard span size because we exit the above loop before adding it.
shardSpanSize += iterator.shardSpanSize();
// Confirm the shard span size adds to about 1
assertEquals(1d, shardSpanSize, 0.001);

// If we have more split points than the initialSplitPoints, we had to compute additional
// tokens, so the best we can do is confirm containment.
if (actualSplitPoints.size() >= initialSplitPoints.size())
assertTrue(actualSplitPoints + " does not contain " + initialSplitPoints,
actualSplitPoints.containsAll(initialSplitPoints));
else
assertTrue(initialSplitPoints + " does not contain " + actualSplitPoints,
initialSplitPoints.containsAll(actualSplitPoints));

// Higher tokens must always contain lower tokens.
assertTrue(actualSplitPoints + " does not contain " + lowerTokens,
actualSplitPoints.containsAll(lowerTokens));
lowerTokens = actualSplitPoints;
}
}
}
}
}


private AbstractReplicationStrategy buildStrategy(int numTokens, int numRacks, int numNodes, int rf) throws UnknownHostException
{
DatabaseDescriptor.setPartitionerUnsafe(Murmur3Partitioner.instance);
DatabaseDescriptor.setEndpointSnitch(new RackInferringSnitch());
var config = new Config();
config.num_tokens = numTokens;
DatabaseDescriptor.setConfig(config);
var tokenMetadata = new TokenMetadata();
var snitch = DatabaseDescriptor.getEndpointSnitch();
var dc = DatabaseDescriptor.getEndpointSnitch().getLocalDatacenter();
// Configure rf
var options = Map.of(dc, Integer.toString(rf));
var networkTopology = new NetworkTopologyStrategy("0", tokenMetadata, snitch, options);

for (int i = 0; i < numRacks; i++)
generateFakeEndpoints(tokenMetadata, networkTopology, 1, numNodes, numTokens, dc, Integer.toString(i));

return networkTopology;
}

// Generates endpoints and adds them to the tmd and the rs.
private List<Token> generateFakeEndpoints(TokenMetadata tmd, AbstractReplicationStrategy rs, int firstNodeId, int lastNodId, int vnodes, String dc, String rack) throws UnknownHostException
{
System.out.printf("Adding nodes %d through %d to dc=%s, rack=%s.%n", firstNodeId, lastNodId, dc, rack);
var result = new ArrayList<Token>();
for (int i = firstNodeId; i <= lastNodId; i++)
{
// leave .1 for myEndpoint
InetAddressAndPort addr = InetAddressAndPort.getByName("127." + dc + '.' + rack + '.' + (i + 1));
var tokens = TokenAllocation.allocateTokens(tmd, rs, addr, vnodes);
// TODO why don't we need addBootstrapTokens here? The test only passes with updateNormalTokens.
// tmd.addBootstrapTokens(tokens, addr);
tmd.updateNormalTokens(tokens, addr);
result.addAll(tokens);
}
return result;
}
}
@@ -315,8 +315,6 @@ private int[] getShardBoundaries(int numShards, List<Token> diskBoundaries, Sort
         when(db.getPositions()).thenReturn(diskBoundaries);
 
         var rs = Mockito.mock(AbstractReplicationStrategy.class);
-
-        // todo use OfflineTokenAllocator to make this test work for the new class
         final ShardTracker shardTracker = ShardManager.create(db, rs, false)
                                                       .boundaries(numShards);
         IntArrayList list = new IntArrayList();
@@ -148,7 +148,6 @@ public void testTokenAllocationForMultiNodeMultiRack() throws UnknownHostExcepti
     private List<Token> generateFakeEndpoints(TokenMetadata tmd, AbstractReplicationStrategy rs, int firstNodeId, int lastNodId, int vnodes, String dc, String rack) throws UnknownHostException
     {
         System.out.printf("Adding nodes %d through %d to dc=%s, rack=%s.%n", firstNodeId, lastNodId, dc, rack);
-        IPartitioner p = tmd.partitioner;
         var result = new ArrayList<Token>();
         for (int i = firstNodeId; i <= lastNodId; i++)
         {

1 comment on commit 778d47b

@cassci-bot

Build rejected: 4 NEW test failure(s) in 1 build. Build 1: ran 17581 tests with 10 failures and 128 skipped.
Butler analysis done on ds-cassandra-pr-gate/node-aware-shard-manager vs last 16 runs of ds-cassandra-build-nightly/main.
org.apache.cassandra.index.sai.cql.QueryWriteLifecycleTest.testWriteLifecycle[aa_CompoundKeyDataModel{primaryKey=p, c}]: test is constantly failing. No failures on upstream;
branch story: [F] vs upstream: [++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++]; [NEW]
org.apache.cassandra.index.sai.cql.TinySegmentQueryWriteLifecycleTest.testWriteLifecycle[aa_BaseDataModel{primaryKey=p}]: test is constantly failing. No failures on upstream;
branch story: [F] vs upstream: [++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++]; [NEW]
org.apache.cassandra.utils.binlog.BinLogTest.testTruncationReleasesLogSpace: test is constantly failing. No failures on upstream;
branch story: [F] vs upstream: [++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++]; [NEW]
org.apache.cassandra.index.sai.cql.VectorSiftSmallTest.testMultiSegmentBuild: test is constantly failing. No failures on upstream;
branch story: [F] vs upstream: [++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++]; [NEW]
butler comparison
