aboutsummaryrefslogtreecommitdiff
path: root/db/memtable.cu
diff options
context:
space:
mode:
Diffstat (limited to 'db/memtable.cu')
-rw-r--r--db/memtable.cu246
1 files changed, 245 insertions, 1 deletions
diff --git a/db/memtable.cu b/db/memtable.cu
index 241a35b..51ca578 100644
--- a/db/memtable.cu
+++ b/db/memtable.cu
@@ -20,7 +20,6 @@ static Slice GetLengthPrefixedSlice(const char* data) {
MemTable::MemTable(const InternalKeyComparator& comparator)
: comparator_(comparator), refs_(0) {
-
}
MemTable::~MemTable() { assert(refs_ == 0); }
@@ -266,4 +265,249 @@ bool MemTable::Get(const LookupKey& key, std::string* value, Status* s) {
return false;
}
+__device__ const char* GetVarint32PtrFallbackCuda(const char* p, const char* limit,
+ uint32_t* value) {
+ uint32_t result = 0;
+ for (uint32_t shift = 0; shift <= 28 && p < limit; shift += 7) {
+ uint32_t byte = *(reinterpret_cast<const uint8_t*>(p));
+ p++;
+ if (byte & 128) {
+ // More bytes are present
+ result |= ((byte & 127) << shift);
+ } else {
+ result |= (byte << shift);
+ *value = result;
+ return reinterpret_cast<const char*>(p);
+ }
+ }
+ return nullptr;
+}
+
+__device__ inline const char* GetVarint32PtrCuda(const char* p, const char* limit,
+ uint32_t* value) {
+ if (p < limit) {
+ uint32_t result = *(reinterpret_cast<const uint8_t*>(p));
+ if ((result & 128) == 0) {
+ *value = result;
+ return p + 1;
+ }
+ }
+ return GetVarint32PtrFallbackCuda(p, limit, value);
+}
+
+struct SizedSlice {
+ const char * ptr;
+ size_t len;
+
+ __device__ __host__ explicit SizedSlice(const char * ptr, size_t len) {
+ this->ptr = ptr;
+ this->len = len;
+ }
+};
+
+__device__ static SizedSlice GetLengthPrefixedSliceCuda(const char* data) {
+ uint32_t len;
+ const char* p = data;
+ p = GetVarint32PtrCuda(p, p + 5, &len); // +5: we assume "p" is not corrupted
+ return SizedSlice(p, len);
+}
+
+
+__device__ MemTableCuda::MemTableCuda(const InternalKeyComparator& comparator)
+ : comparator_(comparator), refs_(0), table_(comparator_, &arena_) {
+}
+
+MemTableCuda::~MemTableCuda() { assert(refs_ == 0); }
+
+size_t MemTableCuda::ApproximateMemoryUsage() { return arena_.MemoryUsage(); }
+
+__device__ int MemTableCuda::KeyComparator::operator()(const char* aptr,
+ const char* bptr) const {
+ // Internal keys are encoded as length-prefixed strings.
+ Slice a = GetLengthPrefixedSliceCuda(aptr);
+ Slice b = GetLengthPrefixedSliceCuda(bptr);
+ return comparator.Compare(a, b);
+}
+
+class MemTableCudaIterator : public Iterator {
+ public:
+ explicit MemTableCudaIterator(MemTableCuda::Table* table) : iter_(table) {}
+
+ MemTableCudaIterator(const MemTableCudaIterator&) = delete;
+ MemTableCudaIterator& operator=(const MemTableCudaIterator&) = delete;
+
+ ~MemTableCudaIterator() override = default;
+
+ bool Valid() const override { return iter_.Valid(); }
+ void Seek(const Slice& k) override {
+ char * key = nullptr;
+ EncodeKey(&tmp_, k);
+ Seek_<<<1, 1>>>(this, key);
+ }
+ void SeekToFirst() override { iter_.SeekToFirst(); }
+ void SeekToLast() override { iter_.SeekToLast(); }
+ void Next() override { iter_.Next(); }
+ void Prev() override { Prev_<<<1, 1>>>(this); }
+ Slice key() const override { return GetLengthPrefixedSlice(iter_.key()); }
+ Slice value() const override {
+ Slice key_slice = GetLengthPrefixedSlice(iter_.key());
+ return GetLengthPrefixedSlice(key_slice.data() + key_slice.size());
+ }
+
+ Status status() const override { return Status::OK(); }
+
+ static __global__ void Prev_(MemTableCudaIterator* this_) {
+ this_->iter_.Prev();
+ }
+
+ static __global__ void Seek_(MemTableCudaIterator* this_, const char * key) {
+ this_->iter_.Seek(key);
+ }
+
+ private:
+ MemTableCuda::Table::Iterator iter_;
+ std::string tmp_; // For passing to EncodeKey
+};
+
+Iterator* MemTableCuda::NewIterator() { return new MemTableCudaIterator(&table_); }
+
+
+__global__ void MemTableCudaAdd_(MemTableCuda * mtb, size_t encoded_len, char * encode_data) {
+ char* buf = mtb->arena_.Allocate(encoded_len);
+ memcpy(buf, encode_data, encoded_len);
+ mtb->table_.Insert(buf);
+}
+
+
+void MemTableCuda::Add(SequenceNumber s, ValueType type, const Slice& key,
+ const Slice& value) {
+ // Format of an entry is concatenation of:
+ // key_size : varint32 of internal_key.size()
+ // key bytes : char[internal_key.size()]
+ // tag : uint64((sequence << 8) | type)
+ // value_size : varint32 of value.size()
+ // value point : point to host memory
+ size_t key_size = key.size();
+ size_t val_size = value.size();
+ size_t internal_key_size = key_size + 8;
+ //const size_t encoded_len = VarintLength(internal_key_size) +
+ // internal_key_size + VarintLength(val_size) +
+ // val_size;
+ const size_t encoded_len = VarintLength(internal_key_size) + internal_key_size +
+ VarintLength(val_size) + 8;
+
+
+ char * key_mem = this->arena_.Allocate( key_size);
+ std::memcpy(key_mem, key.data(), key_size);
+
+ //char * tag_mem = key_mem + key_size;
+ //EncodeFixed64(tag_mem, (s << 8) | type);
+
+ char * val_mem = this->arena_.Allocate(val_size);
+ std::memcpy(val_mem, value.data(), val_size);
+
+ char * insert_val = new char[encoded_len], *cuda_insert = nullptr;
+ // EncodeVarint32(insert_val, encoded_len);
+ char * p = EncodeVarint32(insert_val + 8, internal_key_size);
+ //EncodeFixed64(p, reinterpret_cast<uint64_t>(key_mem));
+ memcpy(p, key_mem, key_size);
+ p += key_size;
+ EncodeFixed64(p, (s << 8) | type );
+ p += 8;
+ //EncodeFixed64(p, reinterpret_cast<uint64_t>(val_size));
+ p = EncodeVarint32(p, val_size);
+ EncodeFixed64(p, reinterpret_cast<uint64_t>(val_mem));
+
+ assert(p + 8 == insert_val + encoded_len);
+
+ cudaMalloc((void**)&cuda_insert, encoded_len);
+ cudaMemcpy(cuda_insert, insert_val, encoded_len, cudaMemcpyHostToDevice);
+
+ MemTableCudaAdd_<<<1, 1>>>(this, encoded_len, cuda_insert);
+ cudaDeviceSynchronize();
+
+ cudaFree(cuda_insert);
+ delete[] insert_val;
+}
+
+__global__ void MemTableCudaGet_(MemTableCuda * met, char * memkey, char ** data, size_t* malloc_size) {
+ *data = nullptr;
+ auto iter = met->getIter();
+ iter.Seek(memkey);
+ if (iter.Valid()) {
+ // entry format is:
+ // klength varint32
+ // userkey char[klength]
+ // tag uint64
+ // vlength varint32
+ // value char[vlength]
+ // Check that it belongs to same user key. We do not check the
+ // sequence number since the Seek() call above should have skipped
+ // all entries with overly large sequence numbers.
+ size_t key_size;
+ const char *entry = iter.key();
+ const char * p = GetVarint32PtrCuda(entry, entry + 5,
+ reinterpret_cast<uint32_t*>(&key_size));
+ p = GetVarint32PtrCuda(p + key_size, p + key_size + 5, nullptr);
+ *malloc_size = (p - entry + 8);
+ cudaMalloc((void**)*data, *malloc_size);
+ memcpy(*data, entry, *malloc_size);
+ }
+}
+
+bool MemTableCuda::Get(const LookupKey& key, std::string* value, Status* s) {
+
+ Slice memkey = key.memtable_key();
+ char * cuda_mem_key = nullptr;
+ cudaMalloc((void**)&cuda_mem_key, memkey.size());
+ cudaMemcpy(cuda_mem_key, memkey.data(), memkey.size(), cudaMemcpyHostToDevice);
+
+ char ** cuda_skiplist_key = nullptr;
+ cudaMalloc((void**)&cuda_skiplist_hkey, sizeof(char *));
+
+ size_t * cuda_malloc_size = nullptr;
+ cudaMalloc((void**)&cuda_malloc_size, sizeof(cuda_malloc_size));
+
+ MemTableCudaGet_<<<1,1>>>(this, cuda_mem_key, cuda_skiplist_key, cuda_malloc_size);
+ cudaDeviceSynchronize();
+ cudaFree(cuda_mem_key);
+
+ if (*cuda_skiplist_key == nullptr) {
+ cudaFree(cuda_skiplist_key);
+ cudaFree(cuda_malloc_size);
+ return false;
+ }
+ auto * malloc_size = new size_t;
+ cudaMemcpy(malloc_size, cuda_malloc_size, sizeof(size_t), cudaMemcpyDeviceToHost);
+
+ char * entry = this->host_arena_.Allocate(*malloc_size);
+ cudaMemcpy(entry, cuda_skiplist_key, *malloc_size, cudaMemcpyDeviceToHost);
+
+ cudaFree(cuda_skiplist_key);
+ cudaFree(cuda_malloc_size);
+ delete malloc_size;
+
+ uint32_t key_length;
+ const char* key_ptr = GetVarint32Ptr(entry, entry + 5, &key_length);
+ if (comparator_.comparator.user_comparator()->Compare(
+ Slice(key_ptr, key_length - 8), key.user_key()) == 0) {
+ // Correct user key
+ const uint64_t tag = DecodeFixed64(key_ptr + key_length - 8);
+ switch (static_cast<ValueType>(tag & 0xff)) {
+ case kTypeValue: {
+ // HOW CAN WE GET IT?
+ Slice v = GetLengthPrefixedSlice(key_ptr + key_length);
+ value->assign(v.data(), v.size());
+ return true;
+ }
+ case kTypeDeletion:
+ *s = Status::NotFound(Slice());
+ return true;
+ }
+ }
+
+
+ return false;
+}
+
} // namespace leveldb