diff --git a/src/main/scala/io/iohk/iodb/LSMStore.scala b/src/main/scala/io/iohk/iodb/LSMStore.scala index 7efe309..3b4f6d3 100644 --- a/src/main/scala/io/iohk/iodb/LSMStore.scala +++ b/src/main/scala/io/iohk/iodb/LSMStore.scala @@ -153,7 +153,8 @@ class LSMStore( } var counter = 0 - override def get(key: K): V = { + + override def get(key: K): Option[V] = { lock.readLock().lock() try { counter += 1 @@ -164,11 +165,11 @@ class LSMStore( .getValue val mainLogVersion = lastVersion val shardVersion = shard.lastVersion - if(mainLogVersion!=shardVersion){ + if (mainLogVersion != shardVersion) { //some entries were not sharded yet, try main log - val ret = mainLog.get(key=key, versionId = mainLogVersion, stopAtVersion = shardVersion) - if(ret!=null) - return ret.getOrElse(null); //null is for tombstones found in main log + val ret = mainLog.get(key = key, versionId = mainLogVersion, stopAtVersion = shardVersion) + if (ret != null) + return ret } return shard.get(key) @@ -178,7 +179,7 @@ class LSMStore( } /** gets value from sharded log, ignore main log */ - protected[iodb] def getFromShard(key: K): V = { + protected[iodb] def getFromShard(key: K): Option[V] = { lock.readLock().lock() try { val shard = shards.lastEntry().getValue.floorEntry(key).getValue diff --git a/src/main/scala/io/iohk/iodb/LogStore.scala b/src/main/scala/io/iohk/iodb/LogStore.scala index 82daf7d..72a8769 100644 --- a/src/main/scala/io/iohk/iodb/LogStore.scala +++ b/src/main/scala/io/iohk/iodb/LogStore.scala @@ -218,16 +218,13 @@ class LogStore( iter.toBuffer } - def get(key: K): V = { - val v = get(key, lastVersion) - if (v == null || v.isEmpty) - return null - v.get + def get(key: K): Option[V] = { + return get(key, lastVersion) } protected[iodb] def get(key: K, versionId: Long, stopAtVersion: Long = 0): Option[V] = { if (files.isEmpty) - return null + return None val versions = if (stopAtVersion > 0) files.subMap(versionId, true, stopAtVersion, false).asScala @@ -240,9 +237,9 @@ class LogStore( if (ret != null) return Some(ret) // value was found if (logFile.isMerged) - return null //contains all versions, will not be found in next versions + return None //contains all versions, will not be found in next versions } - null + None } protected def versionGet(logFile: LogFile, key: K): V = { diff --git a/src/main/scala/io/iohk/iodb/Store.scala b/src/main/scala/io/iohk/iodb/Store.scala index 1634e78..a654a7a 100644 --- a/src/main/scala/io/iohk/iodb/Store.scala +++ b/src/main/scala/io/iohk/iodb/Store.scala @@ -29,7 +29,16 @@ trait Store { * @param key to lookup * @return value associated with key or null */ - def get(key: K): V + def get(key: K): Option[V] + + /** Returns value associated with the key, or defualt value from user + */ + def getOrElse(key: K, default: => V): V = get(key).getOrElse(default) + + /** returns value associated with the key or throws `NoSuchElementException` */ + def apply(key: K): V = getOrElse(key, { + throw new NoSuchElementException() + }) /** * Batch get. @@ -43,9 +52,9 @@ trait Store { * @param keys keys to loopup * @return iterable over key-value pairs found in store */ - def get(keys: Iterable[K]): Iterable[(K, V)] = { - val ret = scala.collection.mutable.ArrayBuffer.empty[(K, V)] - get(keys, (key: K, value: V) => + def get(keys: Iterable[K]): Iterable[(K, Option[V])] = { + val ret = scala.collection.mutable.ArrayBuffer.empty[(K, Option[V])] + get(keys, (key: K, value: Option[V]) => ret += ((key, value)) ) ret @@ -56,14 +65,13 @@ trait Store { * * Finds all keys from given iterable. * Results are passed to callable consumer. - * If key is not found, null value is passed to callable consumer. * * It uses lattest (most recent) version available in store * * @param keys keys to lookup * @param consumer callback method to consume results */ - def get(keys: Iterable[K], consumer: (K, V) => Unit): Unit = { + def get(keys: Iterable[K], consumer: (K, Option[V]) => Unit): Unit = { for (key <- keys) { val value = get(key) consumer(key, value) diff --git a/src/test/scala/io/iohk/iodb/LSMStoreTest.scala b/src/test/scala/io/iohk/iodb/LSMStoreTest.scala index 2c691d3..0a18803 100644 --- a/src/test/scala/io/iohk/iodb/LSMStoreTest.scala +++ b/src/test/scala/io/iohk/iodb/LSMStoreTest.scala @@ -46,7 +46,7 @@ class LSMStoreTest extends TestWithTempDir { assert(!lastFiles.logFile.exists()) assert(store.mainLog.files.size == 2) - assert(store.get(key) == fromLong(2)) + assert(store.get(key) == Some(fromLong(2))) store.close() } @@ -77,7 +77,7 @@ class LSMStoreTest extends TestWithTempDir { assert(!lastFiles.logFile.exists()) assert(store.mainLog.files.size == 2) - assert(store.get(key) == fromLong(2)) + assert(store(key) == fromLong(2)) store.close() } @@ -132,7 +132,7 @@ class LSMStoreTest extends TestWithTempDir { assert(!lastFiles.logFile.exists()) assert(store.mainLog.files.size == 2) - assert(store.get(key) == fromLong(2)) + assert(store(key) == fromLong(2)) //ensure shard layout was restored assert(store.shards.size() == 1) @@ -175,7 +175,7 @@ class LSMStoreTest extends TestWithTempDir { splitSize = 1024) for (i <- 1 until keyCount) { - val value = store.get(fromLong(i)) + val value = store(fromLong(i)) assert(value == fromLong((commitCount - 1) * i)) } store.close() diff --git a/src/test/scala/io/iohk/iodb/LogShardTest.scala b/src/test/scala/io/iohk/iodb/LogShardTest.scala index 40fae31..105a6e3 100644 --- a/src/test/scala/io/iohk/iodb/LogShardTest.scala +++ b/src/test/scala/io/iohk/iodb/LogShardTest.scala @@ -27,7 +27,7 @@ class LogShardTest extends TestWithTempDir { val key = TestUtils.fromLong(i) s.update(TestUtils.fromLong(i), Nil, List((key, key))) s.taskShardLogForce() - assert(s.getFromShard(key) === key) + assert(s.getFromShard(key) === Some(key)) s.close() } } diff --git a/src/test/scala/io/iohk/iodb/LogStoreTest.scala b/src/test/scala/io/iohk/iodb/LogStoreTest.scala index 8be3318..12db413 100644 --- a/src/test/scala/io/iohk/iodb/LogStoreTest.scala +++ b/src/test/scala/io/iohk/iodb/LogStoreTest.scala @@ -20,8 +20,7 @@ class LogStoreTest extends TestWithTempDir { Seq.empty, s.map { a => (a, a) }) for (a <- s) { - val a2 = store.get(a) - assert(a == a2) + assert(Some(a) == store.get(a)) } store.close() } @@ -52,7 +51,7 @@ class LogStoreTest extends TestWithTempDir { def checkExists(version: Long) = { for (i <- 1L until 100) { val b = TestUtils.fromLong(i) - assert(b == store.get(b)) + assert(Some(b) == store.get(b)) assert((i == version) == LogStore.logFile(i, dir = dir, filePrefix = filePrefix, isMerged = true).exists()) assert((i > version) == LogStore.logFile(i, dir = dir, filePrefix = filePrefix).exists()) diff --git a/src/test/scala/io/iohk/iodb/StoreBurnTest.scala b/src/test/scala/io/iohk/iodb/StoreBurnTest.scala index e650960..dcb3abc 100644 --- a/src/test/scala/io/iohk/iodb/StoreBurnTest.scala +++ b/src/test/scala/io/iohk/iodb/StoreBurnTest.scala @@ -21,7 +21,7 @@ abstract class StoreBurnTest extends TestWithTempDir { keys.foreach { it => if (it != store.get(it)) store.get(it) - assert(it == store.get(it)) + assert(it == store(it)) } val newKeys = (0 until 10000).map(i => TestUtils.randomA()) @@ -50,7 +50,7 @@ abstract class StoreBurnTest extends TestWithTempDir { while (System.currentTimeMillis() < endTime) { keys.foreach { it => - assert(it == store.get(it)) + assert(it == store(it)) } val newKeys = (0 until 10000).map(i => TestUtils.randomA()) @@ -87,8 +87,7 @@ abstract class StoreBurnTest extends TestWithTempDir { var version = 1 while (System.currentTimeMillis() < endTime) { keys.foreach { it => - - assert(it == store.get(it)) + assert(it == store(it)) } val newKeys = (0 until 10000).map(i => TestUtils.randomA()) @@ -121,7 +120,7 @@ abstract class StoreBurnTest extends TestWithTempDir { while (System.currentTimeMillis() < endTime) { keys.foreach { it => - assert(it === store.get(it)) + assert(it == store(it)) } val newKeys = (0 until 10000).map(i => TestUtils.randomA()) diff --git a/src/test/scala/io/iohk/iodb/StoreTest.scala b/src/test/scala/io/iohk/iodb/StoreTest.scala index 1d532a7..8c63ec0 100644 --- a/src/test/scala/io/iohk/iodb/StoreTest.scala +++ b/src/test/scala/io/iohk/iodb/StoreTest.scala @@ -34,8 +34,8 @@ abstract class StoreTest extends TestWithTempDir { assert(countFiles() === 1 * numberOfFilesPerUpdate) assert(file(1).exists()) assert(store.lastVersionID === v1) - assert(a(1) === store.get(a(0))) - assert(store.get(a(1)) === null) + assert(Some(a(1)) === store.get(a(0))) + assert(store.get(a(1)) === None) store.update(v2, List(a(0)), List.empty) @@ -43,15 +43,15 @@ abstract class StoreTest extends TestWithTempDir { assert(file(2).exists()) assert(store.lastVersionID === v2) - assert(store.get(a(0)) === null) + assert(store.get(a(0)) === None) store.rollback(v1) assert(countFiles() === 1 * numberOfFilesPerUpdate) assert(file(1).exists()) assert(store.lastVersionID === v1) - assert(a(1) === store.get(a(0))) - assert(store.get(a(1)) === null) + assert(a(1) === store(a(0))) + assert(store.get(a(1)) === None) store.close() } @@ -62,7 +62,7 @@ abstract class StoreTest extends TestWithTempDir { store = makeStore(dir) assert(v3 === store.lastVersionID) - assert(a(1) === store.get(a(0))) + assert(Some(a(1)) === store.get(a(0))) store.close() } @@ -75,13 +75,13 @@ abstract class StoreTest extends TestWithTempDir { store.update(v2, List(a(0)), List((null, a(1)))) } assert(v1 === store.lastVersionID) - assert(a(1) === store.get(a(0))) + assert(Some(a(1)) === store.get(a(0))) intercept[NullPointerException] { store.update(v3, List(null), List((a(0), a(2)))) } assert(v1 === store.lastVersionID) - assert(a(1) === store.get(a(0))) + assert(a(1) === store(a(0))) store.close() } @@ -95,7 +95,7 @@ abstract class StoreTest extends TestWithTempDir { store.update(v2, List(a(0)), List((wrongKey, a(1)))) } assert(v1 === store.lastVersionID) - assert(a(1) === store.get(a(0))) + assert(Some(a(1)) === store.get(a(0))) store.close() } } diff --git a/src/test/scala/io/iohk/iodb/bench/RocksStore.scala b/src/test/scala/io/iohk/iodb/bench/RocksStore.scala index 089881d..ec6dc01 100644 --- a/src/test/scala/io/iohk/iodb/bench/RocksStore.scala +++ b/src/test/scala/io/iohk/iodb/bench/RocksStore.scala @@ -32,9 +32,9 @@ class RocksStore(val dir: File) extends Store { /** returns value associated with key */ - override def get(key: K): V = { + override def get(key: K): Option[V] = { val ret = db.get(key.data) - if (ret == null) null else ByteArrayWrapper(ret) + if (ret == null) None else Some(ByteArrayWrapper(ret)) } /** returns versionID from last update, used when Scorex starts */