diff options
Diffstat (limited to 'db/memtable.cu')
-rw-r--r-- | db/memtable.cu | 246 |
1 files changed, 245 insertions, 1 deletions
diff --git a/db/memtable.cu b/db/memtable.cu index 241a35b..51ca578 100644 --- a/db/memtable.cu +++ b/db/memtable.cu @@ -20,7 +20,6 @@ static Slice GetLengthPrefixedSlice(const char* data) { MemTable::MemTable(const InternalKeyComparator& comparator) : comparator_(comparator), refs_(0) { - } MemTable::~MemTable() { assert(refs_ == 0); } @@ -266,4 +265,249 @@ bool MemTable::Get(const LookupKey& key, std::string* value, Status* s) { return false; } +__device__ const char* GetVarint32PtrFallbackCuda(const char* p, const char* limit, + uint32_t* value) { + uint32_t result = 0; + for (uint32_t shift = 0; shift <= 28 && p < limit; shift += 7) { + uint32_t byte = *(reinterpret_cast<const uint8_t*>(p)); + p++; + if (byte & 128) { + // More bytes are present + result |= ((byte & 127) << shift); + } else { + result |= (byte << shift); + *value = result; + return reinterpret_cast<const char*>(p); + } + } + return nullptr; +} + +__device__ inline const char* GetVarint32PtrCuda(const char* p, const char* limit, + uint32_t* value) { + if (p < limit) { + uint32_t result = *(reinterpret_cast<const uint8_t*>(p)); + if ((result & 128) == 0) { + *value = result; + return p + 1; + } + } + return GetVarint32PtrFallbackCuda(p, limit, value); +} + +struct SizedSlice { + const char * ptr; + size_t len; + + __device__ __host__ explicit SizedSlice(const char * ptr, size_t len) { + this->ptr = ptr; + this->len = len; + } +}; + +__device__ static SizedSlice GetLengthPrefixedSliceCuda(const char* data) { + uint32_t len; + const char* p = data; + p = GetVarint32PtrCuda(p, p + 5, &len); // +5: we assume "p" is not corrupted + return SizedSlice(p, len); +} + + +__device__ MemTableCuda::MemTableCuda(const InternalKeyComparator& comparator) + : comparator_(comparator), refs_(0), table_(comparator_, &arena_) { +} + +MemTableCuda::~MemTableCuda() { assert(refs_ == 0); } + +size_t MemTableCuda::ApproximateMemoryUsage() { return arena_.MemoryUsage(); } + +__device__ int MemTableCuda::KeyComparator::operator()(const char* aptr, + const char* bptr) const { + // Internal keys are encoded as length-prefixed strings. + Slice a = GetLengthPrefixedSliceCuda(aptr); + Slice b = GetLengthPrefixedSliceCuda(bptr); + return comparator.Compare(a, b); +} + +class MemTableCudaIterator : public Iterator { + public: + explicit MemTableCudaIterator(MemTableCuda::Table* table) : iter_(table) {} + + MemTableCudaIterator(const MemTableCudaIterator&) = delete; + MemTableCudaIterator& operator=(const MemTableCudaIterator&) = delete; + + ~MemTableCudaIterator() override = default; + + bool Valid() const override { return iter_.Valid(); } + void Seek(const Slice& k) override { + char * key = nullptr; + EncodeKey(&tmp_, k); + Seek_<<<1, 1>>>(this, key); + } + void SeekToFirst() override { iter_.SeekToFirst(); } + void SeekToLast() override { iter_.SeekToLast(); } + void Next() override { iter_.Next(); } + void Prev() override { Prev_<<<1, 1>>>(this); } + Slice key() const override { return GetLengthPrefixedSlice(iter_.key()); } + Slice value() const override { + Slice key_slice = GetLengthPrefixedSlice(iter_.key()); + return GetLengthPrefixedSlice(key_slice.data() + key_slice.size()); + } + + Status status() const override { return Status::OK(); } + + static __global__ void Prev_(MemTableCudaIterator* this_) { + this_->iter_.Prev(); + } + + static __global__ void Seek_(MemTableCudaIterator* this_, const char * key) { + this_->iter_.Seek(key); + } + + private: + MemTableCuda::Table::Iterator iter_; + std::string tmp_; // For passing to EncodeKey +}; + +Iterator* MemTableCuda::NewIterator() { return new MemTableCudaIterator(&table_); } + + +__global__ void MemTableCudaAdd_(MemTableCuda * mtb, size_t encoded_len, char * encode_data) { + char* buf = mtb->arena_.Allocate(encoded_len); + memcpy(buf, encode_data, encoded_len); + mtb->table_.Insert(buf); +} + + +void MemTableCuda::Add(SequenceNumber s, ValueType type, const Slice& key, + const Slice& value) { + // Format of an entry is concatenation of: + // key_size : varint32 of internal_key.size() + // key bytes : char[internal_key.size()] + // tag : uint64((sequence << 8) | type) + // value_size : varint32 of value.size() + // value point : point to host memory + size_t key_size = key.size(); + size_t val_size = value.size(); + size_t internal_key_size = key_size + 8; + //const size_t encoded_len = VarintLength(internal_key_size) + + // internal_key_size + VarintLength(val_size) + + // val_size; + const size_t encoded_len = VarintLength(internal_key_size) + internal_key_size + + VarintLength(val_size) + 8; + + + char * key_mem = this->arena_.Allocate( key_size); + std::memcpy(key_mem, key.data(), key_size); + + //char * tag_mem = key_mem + key_size; + //EncodeFixed64(tag_mem, (s << 8) | type); + + char * val_mem = this->arena_.Allocate(val_size); + std::memcpy(val_mem, value.data(), val_size); + + char * insert_val = new char[encoded_len], *cuda_insert = nullptr; + // EncodeVarint32(insert_val, encoded_len); + char * p = EncodeVarint32(insert_val + 8, internal_key_size); + //EncodeFixed64(p, reinterpret_cast<uint64_t>(key_mem)); + memcpy(p, key_mem, key_size); + p += key_size; + EncodeFixed64(p, (s << 8) | type ); + p += 8; + //EncodeFixed64(p, reinterpret_cast<uint64_t>(val_size)); + p = EncodeVarint32(p, val_size); + EncodeFixed64(p, reinterpret_cast<uint64_t>(val_mem)); + + assert(p + 8 == insert_val + encoded_len); + + cudaMalloc((void**)&cuda_insert, encoded_len); + cudaMemcpy(cuda_insert, insert_val, encoded_len, cudaMemcpyHostToDevice); + + MemTableCudaAdd_<<<1, 1>>>(this, encoded_len, cuda_insert); + cudaDeviceSynchronize(); + + cudaFree(cuda_insert); + delete[] insert_val; +} + +__global__ void MemTableCudaGet_(MemTableCuda * met, char * memkey, char ** data, size_t* malloc_size) { + *data = nullptr; + auto iter = met->getIter(); + iter.Seek(memkey); + if (iter.Valid()) { + // entry format is: + // klength varint32 + // userkey char[klength] + // tag uint64 + // vlength varint32 + // value char[vlength] + // Check that it belongs to same user key. We do not check the + // sequence number since the Seek() call above should have skipped + // all entries with overly large sequence numbers. + size_t key_size; + const char *entry = iter.key(); + const char * p = GetVarint32PtrCuda(entry, entry + 5, + reinterpret_cast<uint32_t*>(&key_size)); + p = GetVarint32PtrCuda(p + key_size, p + key_size + 5, nullptr); + *malloc_size = (p - entry + 8); + cudaMalloc((void**)*data, *malloc_size); + memcpy(*data, entry, *malloc_size); + } +} + +bool MemTableCuda::Get(const LookupKey& key, std::string* value, Status* s) { + + Slice memkey = key.memtable_key(); + char * cuda_mem_key = nullptr; + cudaMalloc((void**)&cuda_mem_key, memkey.size()); + cudaMemcpy(cuda_mem_key, memkey.data(), memkey.size(), cudaMemcpyHostToDevice); + + char ** cuda_skiplist_key = nullptr; + cudaMalloc((void**)&cuda_skiplist_hkey, sizeof(char *)); + + size_t * cuda_malloc_size = nullptr; + cudaMalloc((void**)&cuda_malloc_size, sizeof(cuda_malloc_size)); + + MemTableCudaGet_<<<1,1>>>(this, cuda_mem_key, cuda_skiplist_key, cuda_malloc_size); + cudaDeviceSynchronize(); + cudaFree(cuda_mem_key); + + if (*cuda_skiplist_key == nullptr) { + cudaFree(cuda_skiplist_key); + cudaFree(cuda_malloc_size); + return false; + } + auto * malloc_size = new size_t; + cudaMemcpy(malloc_size, cuda_malloc_size, sizeof(size_t), cudaMemcpyDeviceToHost); + + char * entry = this->host_arena_.Allocate(*malloc_size); + cudaMemcpy(entry, cuda_skiplist_key, *malloc_size, cudaMemcpyDeviceToHost); + + cudaFree(cuda_skiplist_key); + cudaFree(cuda_malloc_size); + delete malloc_size; + + uint32_t key_length; + const char* key_ptr = GetVarint32Ptr(entry, entry + 5, &key_length); + if (comparator_.comparator.user_comparator()->Compare( + Slice(key_ptr, key_length - 8), key.user_key()) == 0) { + // Correct user key + const uint64_t tag = DecodeFixed64(key_ptr + key_length - 8); + switch (static_cast<ValueType>(tag & 0xff)) { + case kTypeValue: { + // HOW CAN WE GET IT? + Slice v = GetLengthPrefixedSlice(key_ptr + key_length); + value->assign(v.data(), v.size()); + return true; + } + case kTypeDeletion: + *s = Status::NotFound(Slice()); + return true; + } + } + + + return false; +} + } // namespace leveldb |