Using MapMaker#makeComputingMap to prevent simultaneous RPCs for the same data

We have a slow backend server that is getting crushed by load and we'd like the middle-tier Scala server to only have one outstanding request to the backend for each unique lookup.
The backend server only stores immutable data, but upon the addition of new data, the middle-tier servers will request the newest data on behalf of the clients and the backend server has a hard time with the load. The immutable data is cached in memcached using unique keys generated upon the write, but the write rate is high so we get a low memcached hit rate.
One idea I have is to use Google Guava's MapMaker#makeComputingMap() to wrap the actual lookup and after ConcurrentMap#get() returns, the middle-tier will save the result and just delete the key from the Map.
This seems a little wasteful, although the code is very easy to write; see below for an example of what I'm thinking.
Is there a more natural data structure, library or part of Guava that would solve this problem?
import com.google.common.collect.MapMaker
object Test
{
  val computer: com.google.common.base.Function[Int,Long] =
  {
    new com.google.common.base.Function[Int,Long] {
      override
      def apply(i: Int): Long =
      {
        val l = System.currentTimeMillis + i
        System.err.println("For " + i + " returning " + l)
        Thread.sleep(2000)
        l
      }
    }
  }
  val map =
  {
    new MapMaker().makeComputingMap[Int,Long](computer)
  }
  def get(k: Int): Long =
  {
    val l = map.get(k)
    map.remove(k)
    l
  }
  def main(args: Array[String]): Unit =
  {
    val t1 = new Thread() {
      override def run(): Unit =
      {
        System.err.println(get(123))
      }
    }
    val t2 = new Thread() {
      override def run(): Unit =
      {
        System.err.println(get(123))
      }
    }
    t1.start()
    t2.start()
    t1.join()
    t2.join()
    System.err.println(get(123))
  }
}
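For comparison, here is a minimal Java 8 sketch of the same "one outstanding request per key" idea without relying on Guava's computing map. It is a swapped-in technique, not what the code above does: callers race to `putIfAbsent` a `CompletableFuture`, the winner calls the backend, everyone else waits on the winner's future, and the entry is removed once the value is in (mirroring `map.remove(k)` above). The class name `SingleFlight` and the `backend` function are hypothetical stand-ins for your own RPC client.

```java
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.function.Function;

/** Hypothetical helper: at most one in-flight backend call per key at any moment. */
final class SingleFlight<K, V> {
    private final ConcurrentMap<K, CompletableFuture<V>> inFlight = new ConcurrentHashMap<>();
    private final Function<K, V> backend; // stand-in for the slow backend RPC

    SingleFlight(Function<K, V> backend) {
        this.backend = backend;
    }

    V get(K key) {
        CompletableFuture<V> mine = new CompletableFuture<>();
        CompletableFuture<V> existing = inFlight.putIfAbsent(key, mine);
        if (existing != null) {
            // Another caller is already fetching this key: wait for its result
            return existing.join();
        }
        try {
            V value = backend.apply(key);
            mine.complete(value);
            return value;
        } catch (Throwable t) {
            mine.completeExceptionally(t); // waiting callers see the failure too
            throw t;
        } finally {
            // Like map.remove(k) in the question: the next request after this one fetches afresh
            inFlight.remove(key, mine);
        }
    }
}
```

Using `putIfAbsent` with a pre-built future, rather than `computeIfAbsent` with the RPC inside the mapping function, keeps the slow call outside `ConcurrentHashMap`'s internal locking, so other keys that hash to the same bin are not blocked while it runs.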

I'm not sure why you implement remove yourself. Why not simply use weak or soft values and let the GC clean up for you?
new MapMaker().weakValues().makeComputingMap[Int, Long](computer)

I think what you do is quite reasonable. You only use the structure to get lock striping on the key, to ensure that accesses to the same key conflict. Don't worry that you don't need a value mapping per key. ConcurrentHashMap and friends are the only structures in the Java libraries plus Guava that offer you lock striping.
This does induce some minor runtime overhead, plus the size of the hashtable, which you don't need (and which might even grow if accesses to the same segment pile up and remove() doesn't keep up).
If you want to make it as cheap as possible, you could code some simple lock striping yourself. Basically, an Object[] (or Array[AnyRef] :)) of N locks (N = concurrency level): you map the hash of the lookup key into this array and lock. Another advantage is that you don't have to do the hashcode tricks that CHM has to do: CHM must split the hashcode into one part to select the lock and another for the needs of the hashtable, whereas you can use the whole of it just for lock selection.
Edit: sketching my comment below:
val concurrencyLevel = 16
// `until` (not `to`) gives exactly concurrencyLevel locks
val locks = (for (i <- 0 until concurrencyLevel) yield new AnyRef).toArray
// K, V, cache and backendServer stand for your own key/value types, cache and backend client
def access(key: K): V = {
  // Mask off the sign bit so a negative hashCode cannot produce a negative index
  val lock = locks((key.hashCode & Int.MaxValue) % locks.length)
  lock synchronized {
    cache.lookup(key) match {
      case Some(v) => v
      case None =>
        val valueFromBackend = backendServer.lookup(key)
        cache.put(key, valueFromBackend)
        valueFromBackend
    }
  }
}
(By the way, is the toArray call needed? Or is the returned IndexedSeq already fast to access by index?)

Related

Spark java.lang.StackOverflowError

I'm using Spark to calculate the PageRank of user reviews, but I keep getting java.lang.StackOverflowError when I run my code on a big dataset (40k entries). When running the code on a small number of entries, it works fine though.
Entry example:
product/productId: B00004CK40 review/userId: A39IIHQF18YGZA review/profileName: C. A. M. Salas review/helpfulness: 0/0 review/score: 4.0 review/time: 1175817600 review/summary: Reliable comedy review/text: Nice script, well acted comedy, and a young Nicolette Sheridan. Cusak is in top form.
The code:
public void calculatePageRank() {
    sc.clearCallSite();
    sc.clearJobGroup();
    JavaRDD<String> rddFileData = sc.textFile(inputFileName).cache();
    sc.setCheckpointDir("pagerankCheckpoint/");
    JavaRDD<String> rddMovieData = rddFileData.map(new Function<String, String>() {
        @Override
        public String call(String arg0) throws Exception {
            String[] data = arg0.split("\t");
            String movieId = data[0].split(":")[1].trim();
            String userId = data[1].split(":")[1].trim();
            return movieId + "\t" + userId;
        }
    });
    JavaPairRDD<String, Iterable<String>> rddPairReviewData = rddMovieData.mapToPair(new PairFunction<String, String, String>() {
        @Override
        public Tuple2<String, String> call(String arg0) throws Exception {
            String[] data = arg0.split("\t");
            return new Tuple2<String, String>(data[0], data[1]);
        }
    }).groupByKey().cache();
    JavaRDD<Iterable<String>> cartUsers = rddPairReviewData.map(f -> f._2());
    List<Iterable<String>> cartUsersList = cartUsers.collect();
    JavaPairRDD<String, String> finalCartesian = null;
    int iterCounter = 0;
    for (Iterable<String> out : cartUsersList) {
        JavaRDD<String> currentUsersRDD = sc.parallelize(Lists.newArrayList(out));
        if (finalCartesian == null) {
            finalCartesian = currentUsersRDD.cartesian(currentUsersRDD);
        } else {
            finalCartesian = currentUsersRDD.cartesian(currentUsersRDD).union(finalCartesian);
            if (iterCounter % 20 == 0) {
                finalCartesian.checkpoint();
            }
        }
        iterCounter++; // without this the counter stays 0 and every iteration is checkpointed
    }
    JavaRDD<Tuple2<String, String>> finalCartesianToTuple = finalCartesian.map(m -> new Tuple2<String, String>(m._1(), m._2()));
    finalCartesianToTuple = finalCartesianToTuple.filter(x -> x._1().compareTo(x._2()) != 0);
    JavaPairRDD<String, String> userIdPairs = finalCartesianToTuple.mapToPair(m -> new Tuple2<String, String>(m._1(), m._2()));
    JavaRDD<String> userIdPairsString = userIdPairs.map(new Function<Tuple2<String, String>, String>() {
        //Tuple2<Tuple2<MovieId, userId>, Tuple2<movieId, userId>>
        @Override
        public String call(Tuple2<String, String> t) throws Exception {
            return t._1 + " " + t._2;
        }
    });
    try {
        //calculate pagerank using this https://github.com/apache/spark/blob/master/examples/src/main/java/org/apache/spark/examples/JavaPageRank.java
        JavaPageRank.calculatePageRank(userIdPairsString, 100);
    } catch (Exception e) {
        // TODO Auto-generated catch block
        e.printStackTrace();
    }
    sc.close();
}
I have multiple suggestions which will help you to greatly improve the performance of the code in your question.
Caching: Caching should be used on those data sets which you need to refer to again and again for the same or different operations (iterative algorithms).
An example is RDD.count: to tell you the number of lines in the file, the file needs to be read. So if you write RDD.count, at this point the file will be read, the lines will be counted, and the count will be returned. What if you call RDD.count again? The same thing: the file will be read and counted again. So what does RDD.cache do? Now, if you run RDD.count the first time, the file will be loaded, cached, and counted. If you call RDD.count a second time, the operation will use the cache. It will just take the data from the cache and count the lines, with no recomputing.
Read more about caching here.
In your code sample you are not reusing anything you've cached, so you can remove the .cache calls.
Parallelization: In the code sample, you've parallelized every individual element of your RDD, which is already a distributed collection. I suggest merging the rddFileData, rddMovieData and rddPairReviewData steps so that they happen in one go.
Get rid of .collect, since that brings the results back to the driver and may be the actual reason for your error.
This problem occurs when your DAG grows large and there are too many levels of transformations in your code. When an action is finally performed, the JVM cannot hold the whole chain of lazily recorded operations, and the stack overflows.
Checkpointing is one option. I would suggest using Spark SQL for this kind of aggregation. If your data is structured, try loading it into DataFrames and performing the grouping and other operations with SQL functions.
When your for loop grows really large, Spark can no longer keep track of the lineage. Enable checkpointing in your for loop to checkpoint your RDD every 10 iterations or so; checkpointing will fix the problem. Don't forget to clean up the checkpoint directory afterwards.
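A toy model, not Spark code, of why the loop overflows and why periodic checkpointing helps: each iteration adds one node that points at its parent, and a recursive walk over that chain needs one stack frame per node. The `Lineage` class and its `checkpoint()` are hypothetical; in Spark itself, `RDD.checkpoint()` only marks the RDD, and the lineage is truncated once the next action materializes it.

```java
/** Toy stand-in for an RDD lineage: every transformation keeps a pointer to its parent. */
final class Lineage {
    private final Lineage parent; // null for a source or a checkpointed node

    private Lineage(Lineage parent) {
        this.parent = parent;
    }

    static Lineage source() {
        return new Lineage(null);
    }

    /** One more transformation (e.g. the union in the question's loop). */
    Lineage transform() {
        return new Lineage(this);
    }

    /** Toy "checkpoint": keep the node but forget how it was derived. */
    Lineage checkpoint() {
        return new Lineage(null);
    }

    /** Recursive walk: needs one stack frame per ancestor, like a deep lineage traversal. */
    int depth() {
        return parent == null ? 0 : 1 + parent.depth();
    }

    /** Builds `iterations` transformations, checkpointing every `every` steps (0 = never). */
    static Lineage build(int iterations, int every) {
        Lineage current = source();
        for (int i = 1; i <= iterations; i++) {
            current = current.transform();
            if (every > 0 && i % every == 0) {
                current = current.checkpoint();
            }
        }
        return current;
    }
}
```

Without checkpointing, `depth()` grows by one per iteration, so a long enough loop eventually throws `StackOverflowError`; with a checkpoint every 10 iterations the walk never goes more than 10 levels deep, however many iterations run.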
http://spark.apache.org/docs/latest/streaming-programming-guide.html#checkpointing
The steps below fixed the StackOverflowError for me. As others pointed out, it's caused by the lineage that Spark keeps building, especially when you have loops/iterations in your code.
Set the checkpoint directory:
spark.sparkContext.setCheckpointDir("./checkpoint")
Checkpoint the DataFrame/RDD you are modifying inside the iteration (for a DataFrame, keep the checkpointed copy it returns):
modifyingDf = modifyingDf.checkpoint()
Cache the DataFrames that are reused in each iteration:
reusedDf.cache()

Spark streaming mapWithState timeout delayed?

I expected the new mapWithState API for Spark 1.6+ to near-immediately remove objects that are timed-out, but there is a delay.
I'm testing the API with the adapted version of the JavaStatefulNetworkWordCount below:
SparkConf sparkConf = new SparkConf()
    .setAppName("JavaStatefulNetworkWordCount")
    .setMaster("local[*]");
JavaStreamingContext ssc = new JavaStreamingContext(sparkConf, Durations.seconds(1));
ssc.checkpoint("./tmp");
StateSpec<String, Integer, Integer, Tuple2<String, Integer>> mappingFunc =
    StateSpec.function((word, one, state) -> {
        if (state.isTimingOut())
        {
            System.out.println("Timing out the word: " + word);
            return new Tuple2<String,Integer>(word, state.get());
        }
        else
        {
            int sum = one.or(0) + (state.exists() ? state.get() : 0);
            Tuple2<String, Integer> output = new Tuple2<String, Integer>(word, sum);
            state.update(sum);
            return output;
        }
    });
JavaMapWithStateDStream<String, Integer, Integer, Tuple2<String, Integer>> stateDstream =
    ssc.socketTextStream(args[0], Integer.parseInt(args[1]),
            StorageLevels.MEMORY_AND_DISK_SER_2)
        .flatMap(x -> Arrays.asList(SPACE.split(x)))
        .mapToPair(w -> new Tuple2<String, Integer>(w, 1))
        .mapWithState(mappingFunc.timeout(Durations.seconds(5)));
stateDstream.stateSnapshots().print();
I run it together with nc (nc -l -p <port>).
When I type a word into the nc window, I see the tuple being printed in the console every second. But the timing-out message doesn't seem to get printed 5 s later, as expected from the timeout I set. The time it takes for the tuple to expire seems to vary between 5 and 20 s.
Am I missing some configuration option, or is the timeout perhaps only performed at the same time as checkpoints?
Once an event times out it's NOT deleted right away, but is only marked for deletion by saving it to a 'deltaMap':
override def remove(key: K): Unit = {
  val stateInfo = deltaMap(key)
  if (stateInfo != null) {
    stateInfo.markDeleted()
  } else {
    val newInfo = new StateInfo[S](deleted = true)
    deltaMap.update(key, newInfo)
  }
}
Then, timed-out events are collected and sent to the output stream only at checkpoint time. That is: events which time out at batch t will appear in the output stream only at the next checkpoint, by default after 5 batch intervals on average, i.e. at batch t+5:
override def checkpoint(): Unit = {
  super.checkpoint()
  doFullScan = true
}
...
removeTimedoutData = doFullScan // remove timedout data only when full scan is enabled
...
// Get the timed out state records, call the mapping function on each and collect the
// data returned
if (removeTimedoutData && timeoutThresholdTime.isDefined) {
  ...
Elements are actually removed only when there are enough of them, and when the state map is being serialized, which currently also happens only at checkpoint time:
/** Whether the delta chain length is long enough that it should be compacted */
def shouldCompact: Boolean = {
  deltaChainLength >= deltaChainThreshold
}
// Write the data in the parent state map while copying the data into a new parent map for
// compaction (if needed)
val doCompaction = shouldCompact
...
By default checkpointing occurs every 10 iterations, thus in the example above every 10 seconds; since your timeout is 5 seconds, events are expected within 5-15 seconds.
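The 5-15 second window can be checked with a little arithmetic. The sketch below is a simplified model, not Spark code, and rests on two labeled assumptions: timeouts are only processed on checkpoint batches, which fall on multiples of `checkpointIntervalSec`, and a key counts as timed out on batch t when its last update is strictly older than `t - timeout`. `TimeoutDelay` and its methods are hypothetical names.

```java
/** Simplified model of when a timed-out mapWithState key is reported (assumptions in comments). */
final class TimeoutDelay {
    /**
     * Assumption 1: timeouts are only processed on checkpoint batches at t = k * checkpointIntervalSec.
     * Assumption 2: a key times out on batch t when lastUpdateSec < t - timeoutSec.
     * Returns the first checkpoint batch time satisfying both (all arguments non-negative).
     */
    static long reportedAt(long lastUpdateSec, long timeoutSec, long checkpointIntervalSec) {
        long threshold = lastUpdateSec + timeoutSec; // need t > threshold
        // Smallest multiple of checkpointIntervalSec strictly greater than threshold
        return (threshold / checkpointIntervalSec + 1) * checkpointIntervalSec;
    }

    /** Seconds between the key's last update and the timeout being reported. */
    static long delay(long lastUpdateSec, long timeoutSec, long checkpointIntervalSec) {
        return reportedAt(lastUpdateSec, timeoutSec, checkpointIntervalSec) - lastUpdateSec;
    }
}
```

Under these assumptions, with a 5 s timeout and a checkpoint every 10 one-second batches, the delay always lands in (5, 15] seconds, which is the window described above.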
EDIT: Corrected and elaborated answer following comments by @YuvalItzchakov
Am I missing some configuration option, or is the timeout perhaps only performed at the same time as snapshots?
Every time mapWithState is invoked (with your configuration, about once per second), the MapWithStateRDD internally checks for expired records and times them out. You can see it in the code:
// Get the timed out state records, call the mapping function on each and collect the
// data returned
if (removeTimedoutData && timeoutThresholdTime.isDefined) {
  newStateMap.getByTime(timeoutThresholdTime.get).foreach { case (key, state, _) =>
    wrappedState.wrapTimingOutState(state)
    val returned = mappingFunction(batchTime, key, None, wrappedState)
    mappedData ++= returned
    newStateMap.remove(key)
  }
}
(Other than the time taken to execute each job, it turns out that newStateMap.remove(key) actually only marks entries for deletion. See "Edit" for more.)
You have to take into account the time it takes for each stage to be scheduled, and the time it takes for each execution of such a stage to actually take its turn and run. It isn't precise, because this runs as a distributed system where other factors can come into play, making your timeout more or less accurate than you expect it to be.
Edit
As @etov rightly points out, newStateMap.remove(key) doesn't actually remove the element from the OpenHashMapBasedStateMap[K, S], but simply marks it for deletion. This is another reason why you're seeing the expiration time add up.
The actual relevant piece of code is here:
// Write the data in the parent state map while
// copying the data into a new parent map for compaction (if needed)
val doCompaction = shouldCompact
val newParentSessionStore = if (doCompaction) {
  val initCapacity = if (approxSize > 0) approxSize else 64
  new OpenHashMapBasedStateMap[K, S](initialCapacity = initCapacity, deltaChainThreshold)
} else { null }
val iterOfActiveSessions = parentStateMap.getAll()
var parentSessionCount = 0
// First write the approximate size of the data to be written, so that readObject can
// allocate appropriately sized OpenHashMap.
outputStream.writeInt(approxSize)
while(iterOfActiveSessions.hasNext) {
  parentSessionCount += 1
  val (key, state, updateTime) = iterOfActiveSessions.next()
  outputStream.writeObject(key)
  outputStream.writeObject(state)
  outputStream.writeLong(updateTime)
  if (doCompaction) {
    newParentSessionStore.deltaMap.update(
      key, StateInfo(state, updateTime, deleted = false))
  }
}
// Write the final limit marking object with the correct count of records written.
val limiterObj = new LimitMarker(parentSessionCount)
outputStream.writeObject(limiterObj)
if (doCompaction) {
  parentStateMap = newParentSessionStore
}
If deltaMap should be compacted (indicated by the doCompaction variable), then (and only then) is the map cleared of all the deleted instances. How often does that happen? Once the delta chain reaches the threshold:
val DELTA_CHAIN_LENGTH_THRESHOLD = 20
That is, compaction happens only when the delta chain length is at least 20 (deltaChainLength >= deltaChainThreshold); only then are the entries that have been marked for deletion actually dropped.

Processing a list of different types - is using Scala (or functional programming) more expensive than Java?

First of all, let me be clear that I am very new to Scala and functional programming, so my understanding and implementation may be incorrect or inefficient.
Given a file that looks like this:
type1 param11 param12 ...
type2 param21 param22 ...
type2 param31 param32 ...
type1 param41 param42 ...
...
Basically, each line starts with the type of an object, which can be created from the following parameters on the same line. I'm working on an application which goes through each line, creates an object of the given type and returns the list of lists of all the objects.
In Java, my implementation is like this:
public void parse(List<Type1> type1s, List<Type2> type2s, List<String> lines) {
    for (String line : lines) {
        if (line.startsWith("type1")) {
            Type1 type1 = Type1.createObj(line);
            type1s.add(type1);
        } else if (line.startsWith("type2")) {
            Type2 type2 = Type2.createObj(line);
            type2s.add(type2);
        } else {
            throw new IllegalArgumentException(String.format("Unknown type %s", line));
        }
    }
}
In order to do the same thing in Scala, I do this:
def parse(lines: List[String]): (List[Type1], List[Type2]) = {
  val type1Lines = lines filter (x => x.startsWith("type1"))
  val type2Lines = lines filter (x => x.startsWith("type2"))
  val type1s = type1Lines map (x => Type1.createObj(x))
  val type2s = type2Lines map (x => Type2.createObj(x))
  (type1s, type2s)
}
As I understand it, while my Java implementation only goes through the list once, the Scala one has to do it three times: to filter type1, to filter type2 and to create objects from them. This means the Scala implementation should be slower than the Java one, right? Moreover, the Java implementation is also more memory-efficient, as it only has 3 instances: type1s, type2s and lines. The Scala one, on the other hand, has 5: lines, type1Lines, type2Lines, type1s and type2s.
So my questions are:
Is there a better way to rewrite my Scala implementation so that the list is iterated only once?
Using immutable objects means a new object is created every time; does that mean functional programming requires more memory than other styles?
Update: I created a simple test to demonstrate that the Scala program is slower. The program receives a list of Strings of size 1000000. It iterates through the list and checks each item: if an item starts with "type1", it adds 1 to a list named type1s; otherwise, it adds 2 to another list named type2s.
Java implementation:
public static void test(List<String> lines) {
    System.out.println("START");
    List<Integer> type1s = new ArrayList<Integer>();
    List<Integer> type2s = new ArrayList<Integer>();
    long start = System.currentTimeMillis();
    for (String l : lines) {
        if (l.startsWith("type1")) {
            type1s.add(1);
        } else {
            type2s.add(2);
        }
    }
    long end = System.currentTimeMillis();
    System.out.println(String.format("END after %s milliseconds", end - start));
}
Scala implementation:
def test(lines: List[String]) = {
  println("START")
  val start = java.lang.System.currentTimeMillis()
  val type1Lines = lines filter (x => x.startsWith("type1"))
  val type2Lines = lines filter (x => x.startsWith("type2"))
  val type1s = type1Lines map (x => 1)
  val type2s = type2Lines map (x => 2)
  val end = java.lang.System.currentTimeMillis()
  println("END after %s milliseconds".format(end - start))
}
On average, the Java application took 44 milliseconds while the Scala one needed 200 milliseconds.
import scala.util.Random
object ScalaTester extends App {
  val random = new Random
  test((0 until 1000000).toList map { _ => s"type${random nextInt 10}" })
  def test(lines: List[String]): Unit = {
    val start = System.currentTimeMillis()
    // A single groupBy traversal buckets every line by its type prefix
    val m = lines groupBy {
      case s if s startsWith "type1" => "type1"
      case s if s startsWith "type2" => "type2"
      case _ => ""
    }
    println(s"Total type1: ${m("type1").size}; Total type2: ${m("type2").size}; time=${System.currentTimeMillis() - start}")
  }
}
The real advantage of Scala (and functional programming in general) is the ability to process data by transforming one structure into another.
Of course you can combine maps, flatMaps, filters, groupings and so forth in a single line of code. The result is a single data collection.
You may also apply them one after another, creating a new collection each time, and this does produce a little overhead. But does anyone care about it? Even though you create extra collections, Scala-style programming helps you design parallelism-oriented code (as Niklas already mentioned) and protects you from the very elusive side-effect errors that imperative-style programming is prone to.
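On the first question (iterating the list only once): the `groupBy` above already does a single traversal. For comparison, a minimal Java sketch of the same single-pass split uses `Collectors.partitioningBy`, which walks the list once and fills both buckets. `SinglePassSplit` is a hypothetical name, and the "type1" versus "everything else" split mirrors the benchmark rather than the full parser.

```java
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;

final class SinglePassSplit {
    /** One traversal: key true -> lines starting with "type1", key false -> all other lines. */
    static Map<Boolean, List<String>> split(List<String> lines) {
        return lines.stream()
                .collect(Collectors.partitioningBy(line -> line.startsWith("type1")));
    }
}
```

`partitioningBy` always produces both the `true` and `false` keys, so callers do not need a null check for an empty bucket.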

Getting sub-scores for sub-queries in Lucene

I have constructed a query that's essentially a weighted sum of other queries:
val query = new BooleanQuery
for ((subQuery, weight) <- ...) {
  subQuery.setBoost(weight)
  query.add(subQuery, BooleanClause.Occur.MUST)
}
When I query the index, I get back documents with the overall scores. This is good, but I also need to know what the sub-scores for each of the sub-queries were. How can I get those? Here's what I'm doing now:
for (scoreDoc <- searcher.search(query, nHits).scoreDocs) {
  val score = scoreDoc.score
  val subScores = subQueries.map { subQuery =>
    val weight = searcher.createNormalizedWeight(subQuery)
    val scorer = weight.scorer(reader, true, true)
    scorer.advance(scoreDoc.doc)
    scorer.score
  }
}
I think this gives me the right scores, but it seems wasteful to advance to and re-score the document when I know it's already been scored as part of the overall score.
Is there a more efficient way to get those sub-scores?
[My code here is in Scala, but feel free to respond in Java if that's easier.]
EDIT: Here's what things look like after following Robert Muir's suggestion.
The query:
val query = new BooleanQuery
for ((subQuery, weight) <- ...) {
  val weightedQuery = new BoostedQuery(subQuery, new ConstValueSource(weight))
  query.add(weightedQuery, BooleanClause.Occur.MUST)
}
The search:
val collector = new DocScoresCollector(nHits)
searcher.search(query, collector)
for (docScores <- collector.getDocSubScores) {
  ...
}
The collector:
class DocScoresCollector(maxSize: Int) extends Collector {
  var scorer: Scorer = null
  var subScorers: Seq[Scorer] = null
  val priorityQueue = new DocScoresPriorityQueue(maxSize)
  override def setScorer(scorer: Scorer): Unit = {
    this.scorer = scorer
    // a little reflection hackery is required here because of a bug in
    // BoostedQuery's scorer's getChildren method
    // https://issues.apache.org/jira/browse/LUCENE-4261
    this.subScorers = scorer.getChildren.asScala.map(childScorer =>
      childScorer.child ...some hackery... ).toList
  }
  override def acceptsDocsOutOfOrder: Boolean = false
  override def collect(doc: Int): Unit = {
    this.scorer.advance(doc)
    val score = this.scorer.score
    val subScores = this.subScorers.map(_.score)
    priorityQueue.insertWithOverflow(DocScores(doc, score, subScores))
  }
  override def setNextReader(context: AtomicReaderContext): Unit = {}
  def getDocSubScores: Seq[DocScores] = {
    val buffer = Buffer.empty[DocScores]
    while (this.priorityQueue.size > 0) {
      buffer += this.priorityQueue.pop
    }
    buffer
  }
}
case class DocScores(doc: Int, score: Float, subScores: Seq[Float])
class DocScoresPriorityQueue(maxSize: Int) extends PriorityQueue[DocScores](maxSize) {
  def lessThan(a: DocScores, b: DocScores) = a.score < b.score
}
There is a scorer navigation API. The basic idea is that you write a collector, and in its setScorer method, where normally you would save a reference to that Scorer to later score() each hit, you can now walk the tree of that Scorer's subscorers, and so on.
Note that Scorers have pointers back to the Weight that created them, and the Weight back to the Query.
Using all of this, you can stash away references to the subscorers you care about in your setScorer method, e.g. all the ones created from TermQueries. Then, when scoring hits, you can investigate things like the freq() and score() of those nodes in your collector.
In the 3.x series this is a visitor API limited to boolean relationships. In the 4.x series (as of now only an alpha release), you can just get the child and relationship of each subscorer, so it can work with arbitrary queries (including custom ones you write).
Caveats:
You will need to return false from acceptsDocsOutOfOrder in your collector, as your collector requires document-at-a-time processing for this to work.
You probably want a bugfix branch of the 3.6 series (http://svn.apache.org/repos/asf/lucene/dev/branches/lucene_solr_3_6/) or a snapshot of 4.x (http://svn.apache.org/repos/asf/lucene/dev/branches/branch_4x/). This is because this functionality generally didn't work: disjunctions (OR queries) always set their subscorers 'one doc ahead' of the current document until some things were fixed last week, and those fixes didn't make it in time for 3.6.1. See https://issues.apache.org/jira/browse/LUCENE-3505 for more details.
There aren't really any good examples, except some simple tests that sum up the term frequencies of all the leaf nodes (see below).
Tests:
4.x series: http://svn.apache.org/repos/asf/lucene/dev/branches/branch_4x/lucene/core/src/test/org/apache/lucene/search/TestBooleanQueryVisitSubscorers.java
3.x series: http://svn.apache.org/repos/asf/lucene/dev/branches/lucene_solr_3_6/lucene/core/src/test/org/apache/lucene/search/TestBooleanQueryVisitSubscorers.java

My ConcurrentHashMap's value type is List; how do I make appending to that list thread-safe?

My class extends ConcurrentHashMap[String, immutable.List[String]]
and it has two methods:
def addEntry(key: String, newList: immutable.List[String]) = {
  ...
  //if the key exists, append newList to the existing one
  //otherwise set newList as the value
}
def resetEntry(key: String): Unit = {
  this.remove(key)
}
In order to make the addEntry method thread-safe, I tried:
this.get(key).synchronized {
  //append or set here
}
but that will raise a NullPointerException if the key does not exist. Calling putIfAbsent(key, immutable.List()) before synchronizing won't work either, because after putIfAbsent and before entering the synchronized block, the key may be removed by resetEntry.
Making both addEntry and resetEntry synchronized methods would work, but the lock is too coarse.
So, what could I do?
P.S. This post is similar to "How to make updating BigDecimal within ConcurrentHashMap thread safe", but please help me figure out how to write the code rather than give a general guide.
--update--
Check out https://stackoverflow.com/a/34309186/404145; I solved this more than three years later.
Instead of removing the entry, can you simply clear it? You can still use a synchronized list and ensure atomicity.
def resetEntry(key: String, currentBatchSize: Int): Unit = {
  this.get(key).clear()
}
This assumes that each key already has an entry. If this.get(key) == null, you would want to insert a new synchronizedList, which would act as a clear as well.
After more than 3 years, I think now I can answer my question.
The original problem is:
I have a ConcurrentHashMap[String, List], and many threads are appending values to it. How can I make it thread-safe?
Making addEntry() synchronized will work, right?
map.get(key).synchronized {
  map.put(key, map.get(key) ++ newList)
}
In most cases, yes, except when map.get(key) is null, which causes a NullPointerException.
So what about adding map.putIfAbsent(key, Nil) like this:
map.putIfAbsent(key, Nil)
map.get(key).synchronized {
  map.put(key, map.get(key) ++ newList)
}
Better now, but if another thread calls resetEntry() after putIfAbsent(), we will see a NullPointerException again.
Making both addEntry and resetEntry synchronized methods will work, but the lock is too big.
So what about an entry-level lock when appending and a map-level lock when resetting?
Here comes the ReentrantReadWriteLock:
When calling addEntry(), we acquire a shared lock on the map, which makes appending as concurrent as possible; when calling resetEntry(), we acquire an exclusive lock to make sure that no other threads are changing the map at the same time.
The code looks like this:
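import java.util.concurrent.ConcurrentHashMap
import java.util.concurrent.locks.ReentrantReadWriteLock
import scala.collection.immutable
class MyMap extends ConcurrentHashMap[String, immutable.List[String]] {
  val lock = new ReentrantReadWriteLock()
  def addEntry(key: String, newList: immutable.List[String]): Unit = {
    lock.readLock.lock()
    try {
      // If the key exists, append newList to the existing list; otherwise set newList as the value.
      // merge() does this read-modify-write atomically per key (the entry-level lock).
      this.merge(key, newList, (existing: immutable.List[String], added: immutable.List[String]) => existing ++ added)
    } finally {
      lock.readLock.unlock()
    }
  }
  def resetEntry(key: String, currentBatchSize: Int): Unit = {
    lock.writeLock.lock()
    try {
      this.remove(key)
    } finally {
      lock.writeLock.unlock()
    }
  }
}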
You can try a method inspired by the CAS (Compare and Swap) process:
(in pseudo-java-scala-code, as my Scala is still in its infancy)
def addEntry(key: String, newList: immutable.List[String]): immutable.List[String] = {
  val existing = putIfAbsent(key, newList)
  if (existing != null) {
    existing.synchronized {
      // Ask again for the value within the synchronized block to ensure consistency.
      // This is the compare part of CAS.
      if (get(key) eq existing) {
        put(key, existing ++ newList) // Swap the old value for the new one
      } else {
        throw new java.util.ConcurrentModificationException() // how else to mark failure?
      }
    }
  } else {
    existing
  }
}
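A variation on the CAS idea above, sketched in Java: instead of locking and throwing `ConcurrentModificationException` when the value has changed, retry with `ConcurrentHashMap.replace(key, oldValue, newValue)`, which performs the compare-and-swap atomically. The class and method names (`CasListMap`, `addEntry`, `resetEntry`) are hypothetical, and the values are treated as immutable lists, as in the question.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.concurrent.ConcurrentHashMap;

/** Sketch: lock-free append to an immutable-list value via a compare-and-swap retry loop. */
final class CasListMap {
    private final ConcurrentHashMap<String, List<String>> map = new ConcurrentHashMap<>();

    void addEntry(String key, List<String> newItems) {
        while (true) {
            List<String> existing =
                    map.putIfAbsent(key, Collections.unmodifiableList(new ArrayList<>(newItems)));
            if (existing == null) {
                return; // key was absent: newItems became the value
            }
            List<String> merged = new ArrayList<>(existing);
            merged.addAll(newItems);
            // Atomic compare-and-swap: succeeds only if the value still equals `existing`
            if (map.replace(key, existing, Collections.unmodifiableList(merged))) {
                return;
            }
            // Another thread changed or removed the entry in between: retry
        }
    }

    void resetEntry(String key) {
        map.remove(key);
    }

    List<String> get(String key) {
        return map.get(key);
    }
}
```

Because `replace` compares values with `equals()`, a concurrent change that leaves an equal list in place is treated as no change, which is harmless for append-only lists; a reset that races with an append simply makes the append start a fresh list on retry.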
