diff --git a/mindspore/ccsrc/dataset/api/python_bindings.cc b/mindspore/ccsrc/dataset/api/python_bindings.cc index 214ce4c1537ac658b15947ea7f1e72e3c3115454..9865396a7dc1d860ee028769468906d78d1d9bfc 100644 --- a/mindspore/ccsrc/dataset/api/python_bindings.cc +++ b/mindspore/ccsrc/dataset/api/python_bindings.cc @@ -60,6 +60,7 @@ #include "dataset/kernels/data/to_float16_op.h" #include "dataset/util/random.h" #include "mindrecord/include/shard_operator.h" +#include "mindrecord/include/shard_pk_sample.h" #include "mindrecord/include/shard_sample.h" #include "pybind11/pybind11.h" #include "pybind11/stl.h" @@ -152,9 +153,14 @@ void bindDatasetOps(py::module *m) { }); (void)py::class_>(*m, "MindRecordOp") - .def_static("get_num_rows", [](const std::string &path) { + .def_static("get_num_rows", [](const std::string &path, const py::object &sampler) { int64_t count = 0; - THROW_IF_ERROR(MindRecordOp::CountTotalRows(path, &count)); + std::shared_ptr op; + if (py::hasattr(sampler, "_create_for_minddataset")) { + auto create = sampler.attr("_create_for_minddataset"); + op = create().cast>(); + } + THROW_IF_ERROR(MindRecordOp::CountTotalRows(path, op, &count)); return count; }); @@ -435,6 +441,16 @@ void bindSamplerOps(py::module *m) { (void)py::class_>( *m, "MindrecordSubsetRandomSampler") .def(py::init, uint32_t>(), py::arg("indices"), py::arg("seed") = GetSeed()); + (void)py::class_>( + *m, "MindrecordPkSampler") + .def(py::init([](int64_t kVal, bool shuffle) { + if (shuffle == true) { + return std::make_shared("label", kVal, std::numeric_limits::max(), + GetSeed()); + } else { + return std::make_shared("label", kVal); + } + })); (void)py::class_>(*m, "WeightedRandomSampler") .def(py::init, int64_t, bool>(), py::arg("weights"), py::arg("numSamples"), diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/mindrecord_op.cc b/mindspore/ccsrc/dataset/engine/datasetops/source/mindrecord_op.cc index fbb772af593ebcb5c7d4795b17ffa18dfa676e09..72dee6f2e6e0f929bceec42939e66f4c68706dc2 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/source/mindrecord_op.cc +++ b/mindspore/ccsrc/dataset/engine/datasetops/source/mindrecord_op.cc @@ -655,9 +655,10 @@ Status MindRecordOp::LaunchThreadAndInitOp() { return Status::OK(); } -Status MindRecordOp::CountTotalRows(const std::string dataset_path, int64_t *count) { +Status MindRecordOp::CountTotalRows(const std::string dataset_path, const std::shared_ptr &op, + int64_t *count) { std::unique_ptr shard_reader = std::make_unique(); - MSRStatus rc = shard_reader->CountTotalRows(dataset_path, count); + MSRStatus rc = shard_reader->CountTotalRows(dataset_path, op, count); if (rc == MSRStatus::FAILED) { RETURN_STATUS_UNEXPECTED("MindRecordOp count total rows failed."); } diff --git a/mindspore/ccsrc/dataset/engine/datasetops/source/mindrecord_op.h b/mindspore/ccsrc/dataset/engine/datasetops/source/mindrecord_op.h index aca5c86c2cef0baa2399cb60f98cc0ac90397e64..899919e5290469f58bee98185c5f6ad20273651e 100644 --- a/mindspore/ccsrc/dataset/engine/datasetops/source/mindrecord_op.h +++ b/mindspore/ccsrc/dataset/engine/datasetops/source/mindrecord_op.h @@ -171,7 +171,8 @@ class MindRecordOp : public ParallelOp { int32_t num_rows() const { return num_rows_; } // Getter method - static Status CountTotalRows(const std::string dataset_path, int64_t *count); + static Status CountTotalRows(const std::string dataset_path, const std::shared_ptr &op, + int64_t *count); // Getter method int32_t rows_per_buffer() const { return rows_per_buffer_; } diff --git a/mindspore/ccsrc/mindrecord/include/common/shard_utils.h b/mindspore/ccsrc/mindrecord/include/common/shard_utils.h index d31037c8ad39a5d756a096d3cfd7a8a5e6baa9c3..3af4d7f89137bf7536b6f7bb699233de02c2250e 100644 --- a/mindspore/ccsrc/mindrecord/include/common/shard_utils.h +++ b/mindspore/ccsrc/mindrecord/include/common/shard_utils.h @@ -72,6 +72,8 @@ enum ShardType { enum SamplerType { kCustomTopNSampler, kCustomTopPercentSampler, kSubsetRandomSampler, kPKSampler }; +enum ShuffleType { kShuffleCategory, kShuffleSample }; + const double kEpsilon = 1e-7; const int kThreadNumber = 14; diff --git a/mindspore/ccsrc/mindrecord/include/shard_category.h b/mindspore/ccsrc/mindrecord/include/shard_category.h index b8a761154094783bd4d0b8ba319966930b701f63..b2fe18fbacb19e369249b8a34c62d61c411c6eb7 100644 --- a/mindspore/ccsrc/mindrecord/include/shard_category.h +++ b/mindspore/ccsrc/mindrecord/include/shard_category.h @@ -17,6 +17,8 @@ #ifndef MINDRECORD_INCLUDE_SHARD_CATEGORY_H_ #define MINDRECORD_INCLUDE_SHARD_CATEGORY_H_ +#include +#include #include #include #include @@ -26,16 +28,34 @@ namespace mindspore { namespace mindrecord { class ShardCategory : public ShardOperator { public: - explicit ShardCategory(const std::vector> &categories); + explicit ShardCategory(const std::vector> &categories, + int64_t num_elements = std::numeric_limits::max(), bool replacement = false); + + ShardCategory(const std::string &category_field, int64_t num_elements, + int64_t num_categories = std::numeric_limits::max(), bool replacement = false); ~ShardCategory() override{}; - const std::vector> &get_categories() const; + const std::vector> &get_categories() const { return categories_; } + + const std::string GetCategoryField() const { return category_field_; } + + int64_t GetNumElements() const { return num_elements_; } + + int64_t GetNumCategories() const { return num_categories_; } + + bool GetReplacement() const { return replacement_; } MSRStatus execute(ShardTask &tasks) override; + int64_t GetNumSamples(int64_t dataset_size, int64_t num_classes) override; + private: std::vector> categories_; + std::string category_field_; + int64_t num_elements_; + int64_t num_categories_; + bool replacement_; }; } // namespace mindrecord } // namespace mindspore diff --git a/mindspore/ccsrc/mindrecord/include/shard_operator.h b/mindspore/ccsrc/mindrecord/include/shard_operator.h index 9f302e5321df7b215a5bab0af0cbf089eb531283..7476660a703693cd8253d6d5b6836b144be20745 100644 --- a/mindspore/ccsrc/mindrecord/include/shard_operator.h +++ b/mindspore/ccsrc/mindrecord/include/shard_operator.h @@ -43,6 +43,8 @@ class ShardOperator { virtual MSRStatus execute(ShardTask &tasks) = 0; virtual MSRStatus suf_execute(ShardTask &tasks) { return SUCCESS; } + + virtual int64_t GetNumSamples(int64_t dataset_size, int64_t num_classes) { return -1; } }; } // namespace mindrecord } // namespace mindspore diff --git a/mindspore/ccsrc/mindrecord/include/shard_pk_sample.h b/mindspore/ccsrc/mindrecord/include/shard_pk_sample.h new file mode 100644 index 0000000000000000000000000000000000000000..df3888dad4f0adb1927faa78efe9e87e0e241eb2 --- /dev/null +++ b/mindspore/ccsrc/mindrecord/include/shard_pk_sample.h @@ -0,0 +1,49 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDRECORD_INCLUDE_SHARD_PK_SAMPLE_H_ +#define MINDRECORD_INCLUDE_SHARD_PK_SAMPLE_H_ + +#include +#include +#include +#include +#include "mindrecord/include/shard_operator.h" +#include "mindrecord/include/shard_shuffle.h" +#include "mindrecord/include/shard_category.h" + +namespace mindspore { +namespace mindrecord { +class ShardPkSample : public ShardCategory { + public: + ShardPkSample(const std::string &category_field, int64_t num_elements); + + ShardPkSample(const std::string &category_field, int64_t num_elements, int64_t num_categories); + + ShardPkSample(const std::string &category_field, int64_t num_elements, int64_t num_categories, uint32_t seed); + + ~ShardPkSample() override{}; + + MSRStatus suf_execute(ShardTask &tasks) override; + + private: + bool shuffle_; + std::shared_ptr shuffle_op_; +}; +} // namespace mindrecord +} // namespace mindspore + +#endif // MINDRECORD_INCLUDE_SHARD_PK_SAMPLE_H_ diff --git a/mindspore/ccsrc/mindrecord/include/shard_reader.h b/mindspore/ccsrc/mindrecord/include/shard_reader.h index 5548473cd78a50dae340258935cb95cca18786db..3263b2006d735dd662177641755dea355502f0fc 100644 --- a/mindspore/ccsrc/mindrecord/include/shard_reader.h +++ b/mindspore/ccsrc/mindrecord/include/shard_reader.h @@ -115,9 +115,10 @@ class ShardReader { /// \brief get the number of rows in database /// \param[in] file_path the path of ONE file, any file in dataset is fine + /// \param[in] op smart pointer refer to ShardCategory or ShardSample object /// \param[out] count # of rows /// \return MSRStatus the status of MSRStatus - MSRStatus CountTotalRows(const std::string &file_path, int64_t *count); + MSRStatus CountTotalRows(const std::string &file_path, const std::shared_ptr &op, int64_t *count); /// \brief shuffle task with incremental seed /// \return void @@ -197,6 +198,9 @@ class ShardReader { /// \brief get NLP flag bool get_nlp_flag(); + /// \brief get all classes + MSRStatus GetAllClasses(const std::string &category_field, std::set &categories); + protected: /// \brief sqlite call back function static int SelectCallback(void *p_data, int num_fields, char **p_fields, char **p_col_names); @@ -249,8 +253,8 @@ class ShardReader { const std::vector> &operators); /// \brief create category-applied task list - int CreateTasksByCategory(const std::vector> &row_group_summary, - const std::vector> &operators); + MSRStatus CreateTasksByCategory(const std::vector> &row_group_summary, + const std::shared_ptr &op); /// \brief create task list in row-reader mode MSRStatus CreateTasksByRow(const std::vector> &row_group_summary, @@ -284,6 +288,12 @@ class ShardReader { MSRStatus ReadBlob(const int &shard_id, const uint64_t &page_offset, const int &page_length, const int &buf_id); + /// \brief get classes in one shard + void GetClassesInShard(sqlite3 *db, int shard_id, const std::string sql, std::set &categories); + + /// \brief get number of classes + int64_t GetNumClasses(const std::string &file_path, const std::string &category_field); + protected: uint64_t header_size_; // header size uint64_t page_size_; // page size diff --git a/mindspore/ccsrc/mindrecord/include/shard_sample.h b/mindspore/ccsrc/mindrecord/include/shard_sample.h index 15353fd0ff5d34f9dacbe248f3668470f40caa87..b16fc5cc4f50b1f70464399a0d725f0a67a031a7 100644 --- a/mindspore/ccsrc/mindrecord/include/shard_sample.h +++ b/mindspore/ccsrc/mindrecord/include/shard_sample.h @@ -41,8 +41,11 @@ class ShardSample : public ShardOperator { const std::pair get_partitions() const; MSRStatus execute(ShardTask &tasks) override; + MSRStatus suf_execute(ShardTask &tasks) override; + int64_t GetNumSamples(int64_t dataset_size, int64_t num_classes) override; + private: int numerator_; int denominator_; diff --git a/mindspore/ccsrc/mindrecord/include/shard_shuffle.h b/mindspore/ccsrc/mindrecord/include/shard_shuffle.h index 464881aa7a0a06234cb19457b2a0c620e281107e..027a5ad527bc085619b5906beb114bd47a78994e 100644 --- a/mindspore/ccsrc/mindrecord/include/shard_shuffle.h +++ b/mindspore/ccsrc/mindrecord/include/shard_shuffle.h @@ -24,7 +24,7 @@ namespace mindspore { namespace mindrecord { class ShardShuffle : public ShardOperator { public: - explicit ShardShuffle(uint32_t seed = 0); + explicit ShardShuffle(uint32_t seed = 0, ShuffleType shuffle_type = kShuffleCategory); ~ShardShuffle() override{}; @@ -32,6 +32,7 @@ class ShardShuffle : public ShardOperator { private: uint32_t shuffle_seed_; + ShuffleType shuffle_type_; }; } // namespace mindrecord } // namespace mindspore diff --git a/mindspore/ccsrc/mindrecord/include/shard_task.h b/mindspore/ccsrc/mindrecord/include/shard_task.h index 30ea352ef39478b3f3d541382f814bd3cf1afbdf..b276b5150f1bf2258e97b1c005a0117b60337f86 100644 --- a/mindspore/ccsrc/mindrecord/include/shard_task.h +++ b/mindspore/ccsrc/mindrecord/include/shard_task.h @@ -41,7 +41,9 @@ class ShardTask { std::tuple, std::vector, json> &get_task_by_id(size_t id); - static ShardTask Combine(std::vector &category_tasks); + std::tuple, std::vector, json> &get_random_task(); + + static ShardTask Combine(std::vector &category_tasks, bool replacement, int64_t num_elements); uint32_t categories = 1; diff --git a/mindspore/ccsrc/mindrecord/io/shard_reader.cc b/mindspore/ccsrc/mindrecord/io/shard_reader.cc index fd3fede5a2d65f28bae6c5c1fb27d99a658fc588..9cd02d9120f2a5bb8fa4343ff06cfdc54bf974bd 100644 --- a/mindspore/ccsrc/mindrecord/io/shard_reader.cc +++ b/mindspore/ccsrc/mindrecord/io/shard_reader.cc @@ -315,6 +315,43 @@ MSRStatus ShardReader::ReadAllRowsInShard(int shard_id, const std::string &sql, return ConvertLabelToJson(labels, fs, offsets, shard_id, columns, column_values); } +MSRStatus ShardReader::GetAllClasses(const std::string &category_field, std::set &categories) { + auto ret = ShardIndexGenerator::GenerateFieldName(std::make_pair(column_schema_id_[category_field], category_field)); + if (SUCCESS != ret.first) { + return FAILED; + } + std::string sql = "SELECT DISTINCT " + ret.second + " FROM INDEXES"; + std::vector threads = std::vector(shard_count_); + for (int x = 0; x < shard_count_; x++) { + threads[x] = std::thread(&ShardReader::GetClassesInShard, this, database_paths_[x], x, sql, std::ref(categories)); + } + + for (int x = 0; x < shard_count_; x++) { + threads[x].join(); + } + return SUCCESS; +} + +void ShardReader::GetClassesInShard(sqlite3 *db, int shard_id, const std::string sql, + std::set &categories) { + if (nullptr == db) { + return; + } + std::vector> columns; + char *errmsg = nullptr; + int ret = sqlite3_exec(db, common::SafeCStr(sql), SelectCallback, &columns, &errmsg); + if (ret != SQLITE_OK) { + sqlite3_free(errmsg); + sqlite3_close(db); + MS_LOG(ERROR) << "Error in select sql statement, sql:" << common::SafeCStr(sql) << ", error: " << errmsg; + return; + } + MS_LOG(INFO) << "Get" << static_cast(columns.size()) << " records from shard " << shard_id << " index."; + for (int i = 0; i < static_cast(columns.size()); ++i) { + categories.emplace(columns[i][0]); + } +} + ROW_GROUPS ShardReader::ReadAllRowGroup(std::vector &columns) { std::string fields = "ROW_GROUP_ID, PAGE_OFFSET_BLOB, PAGE_OFFSET_BLOB_END"; std::vector>> offsets(shard_count_, std::vector>{}); @@ -667,11 +704,64 @@ MSRStatus ShardReader::Finish() { return SUCCESS; } -MSRStatus ShardReader::CountTotalRows(const std::string &file_path, int64_t *count) { +int64_t ShardReader::GetNumClasses(const std::string &file_path, const std::string &category_field) { + ShardHeader sh = ShardHeader(); + if (sh.Build(file_path) == FAILED) { + return -1; + } + auto header = std::make_shared(sh); + auto file_paths = header->get_shard_addresses(); + auto shard_count = file_paths.size(); + auto index_fields = header->get_fields(); + + std::map map_schema_id_fields; + for (auto &field : index_fields) { + map_schema_id_fields[field.second] = field.first; + } + auto ret = + ShardIndexGenerator::GenerateFieldName(std::make_pair(map_schema_id_fields[category_field], category_field)); + if (SUCCESS != ret.first) { + return -1; + } + std::string sql = "SELECT DISTINCT " + ret.second + " FROM INDEXES"; + std::vector threads = std::vector(shard_count); + std::set categories; + for (int x = 0; x < shard_count; x++) { + sqlite3 *db = nullptr; + int rc = sqlite3_open_v2(common::SafeCStr(file_paths[x] + ".db"), &db, SQLITE_OPEN_READONLY, nullptr); + if (SQLITE_OK != rc) { + MS_LOG(ERROR) << "Can't open database, error: " << sqlite3_errmsg(db); + return -1; + } + threads[x] = std::thread(&ShardReader::GetClassesInShard, this, db, x, sql, std::ref(categories)); + } + + for (int x = 0; x < shard_count; x++) { + threads[x].join(); + } + return categories.size(); +} + +MSRStatus ShardReader::CountTotalRows(const std::string &file_path, const std::shared_ptr &op, + int64_t *count) { if (Init(file_path) == FAILED) { return FAILED; } - *count = num_rows_; + int64_t num_samples = num_rows_; + if (std::dynamic_pointer_cast(op)) { + auto category_op = std::dynamic_pointer_cast(op); + std::string category_field = category_op->GetCategoryField(); + auto num_classes = GetNumClasses(file_path, category_field); + num_samples = category_op->GetNumSamples(num_rows_, num_classes); + } else if (std::dynamic_pointer_cast(op)) { + num_samples = op->GetNumSamples(num_rows_, 0); + } else { + } + if (-1 == num_samples) { + MS_LOG(ERROR) << "Failed to get dataset size."; + return FAILED; + } + *count = num_samples; return SUCCESS; } @@ -793,6 +883,8 @@ MSRStatus ShardReader::Launch(bool isSimpleReader) { thread_set_[x] = std::thread(&ShardReader::ConsumerByRow, this, x); } } + + MS_LOG(INFO) << "Launch read thread successfully."; return SUCCESS; } @@ -828,44 +920,67 @@ MSRStatus ShardReader::CreateTasksByBlock(const std::vector> &row_group_summary, - const std::vector> &operators) { +MSRStatus ShardReader::CreateTasksByCategory(const std::vector> &row_group_summary, + const std::shared_ptr &op) { vector columns = GetAllColumns(); CheckIfColumnInIndex(columns); - int category_operator = -1; - for (uint32_t i = 0; i < operators.size(); ++i) { - const auto &op = operators[i]; - if (std::dynamic_pointer_cast(op)) category_operator = static_cast(i); + auto category_op = std::dynamic_pointer_cast(op); + auto categories = category_op->get_categories(); + int64_t num_elements = category_op->GetNumElements(); + if (num_elements <= 0) { + MS_LOG(ERROR) << "Parameter num_element is not positive"; + return FAILED; + } + if (categories.empty() == true) { + std::string category_field = category_op->GetCategoryField(); + int64_t num_categories = category_op->GetNumCategories(); + if (num_categories <= 0) { + MS_LOG(ERROR) << "Parameter num_categories is not positive"; + return FAILED; + } + std::set categories_set; + auto ret = GetAllClasses(category_field, categories_set); + if (SUCCESS != ret) { + return FAILED; + } + int i = 0; + for (auto it = categories_set.begin(); it != categories_set.end() && i < num_categories; ++it) { + categories.emplace_back(category_field, *it); + i++; + } } - - if (category_operator == -1) return category_operator; - - auto categories = std::dynamic_pointer_cast(operators[category_operator])->get_categories(); - // Generate task list, a task will create a batch std::vector categoryTasks(categories.size()); for (uint32_t categoryNo = 0; categoryNo < categories.size(); ++categoryNo) { + int category_index = 0; for (const auto &rg : row_group_summary) { + if (category_index >= num_elements) break; auto shard_id = std::get<0>(rg); auto group_id = std::get<1>(rg); auto details = ReadRowGroupCriteria(group_id, shard_id, categories[categoryNo], columns); if (SUCCESS != std::get<0>(details)) { - return -2; + return FAILED; } auto offsets = std::get<4>(details); auto number_of_rows = offsets.size(); for (uint32_t iStart = 0; iStart < number_of_rows; iStart += 1) { - categoryTasks[categoryNo].InsertTask(shard_id, group_id, std::get<4>(details)[iStart], - std::get<5>(details)[iStart]); + if (category_index < num_elements) { + categoryTasks[categoryNo].InsertTask(shard_id, group_id, std::get<4>(details)[iStart], + std::get<5>(details)[iStart]); + category_index++; + } } } MS_LOG(INFO) << "Category #" << categoryNo << " has " << categoryTasks[categoryNo].Size() << " tasks"; } - tasks_ = ShardTask::Combine(categoryTasks); - return category_operator; + tasks_ = ShardTask::Combine(categoryTasks, category_op->GetReplacement(), num_elements); + if (SUCCESS != (*category_op)(tasks_)) { + return FAILED; + } + return SUCCESS; } MSRStatus ShardReader::CreateTasksByRow(const std::vector> &row_group_summary, @@ -896,14 +1011,26 @@ MSRStatus ShardReader::CreateTasksByRow(const std::vector> &row_group_summary, const std::vector> &operators) { if (block_reader_) { - CreateTasksByBlock(row_group_summary, operators); + if (SUCCESS != CreateTasksByBlock(row_group_summary, operators)) { + return FAILED; + } } else { - int category_operator = CreateTasksByCategory(row_group_summary, operators); - if (category_operator == -1) { - CreateTasksByRow(row_group_summary, operators); + int category_operator = -1; + for (uint32_t i = 0; i < operators.size(); ++i) { + const auto &op = operators[i]; + if (std::dynamic_pointer_cast(op)) { + category_operator = static_cast(i); + break; + } } - if (category_operator == -2) { - return FAILED; + if (-1 == category_operator) { + if (SUCCESS != CreateTasksByRow(row_group_summary, operators)) { + return FAILED; + } + } else { + if (SUCCESS != CreateTasksByCategory(row_group_summary, operators[category_operator])) { + return FAILED; + } } } diff --git a/mindspore/ccsrc/mindrecord/meta/shard_category.cc b/mindspore/ccsrc/mindrecord/meta/shard_category.cc index 859a3b343fecc21ca24a0f8bab8e1362462c5e50..80816e7a79de95d4155d73b27a6542b653a61850 100644 --- a/mindspore/ccsrc/mindrecord/meta/shard_category.cc +++ b/mindspore/ccsrc/mindrecord/meta/shard_category.cc @@ -18,11 +18,30 @@ namespace mindspore { namespace mindrecord { -ShardCategory::ShardCategory(const std::vector> &categories) - : categories_(categories) {} +ShardCategory::ShardCategory(const std::vector> &categories, int64_t num_elements, + bool replacement) + : categories_(categories), + category_field_(""), + num_elements_(num_elements), + num_categories_(0), + replacement_(replacement) {} -const std::vector> &ShardCategory::get_categories() const { return categories_; } +ShardCategory::ShardCategory(const std::string &category_field, int64_t num_elements, int64_t num_categories, + bool replacement) + : categories_({}), + category_field_(category_field), + num_elements_(num_elements), + num_categories_(num_categories), + replacement_(replacement) {} MSRStatus ShardCategory::execute(ShardTask &tasks) { return SUCCESS; } + +int64_t ShardCategory::GetNumSamples(int64_t dataset_size, int64_t num_classes) { + if (dataset_size == 0) return dataset_size; + if (dataset_size > 0 && num_categories_ > 0 && num_elements_ > 0) { + return std::min(num_categories_, num_classes) * num_elements_; + } + return -1; +} } // namespace mindrecord } // namespace mindspore diff --git a/mindspore/ccsrc/mindrecord/meta/shard_pk_sample.cc b/mindspore/ccsrc/mindrecord/meta/shard_pk_sample.cc new file mode 100644 index 0000000000000000000000000000000000000000..8e2e892e63b9d39e8f80bb5694962c56811b344f --- /dev/null +++ b/mindspore/ccsrc/mindrecord/meta/shard_pk_sample.cc @@ -0,0 +1,46 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "mindrecord/include/shard_pk_sample.h" + +using mindspore::LogStream; +using mindspore::ExceptionType::NoExceptionType; +using mindspore::MsLogLevel::ERROR; + +namespace mindspore { +namespace mindrecord { +ShardPkSample::ShardPkSample(const std::string &category_field, int64_t num_elements) + : ShardCategory(category_field, num_elements, std::numeric_limits::max(), true), shuffle_(false) {} + +ShardPkSample::ShardPkSample(const std::string &category_field, int64_t num_elements, int64_t num_categories) + : ShardCategory(category_field, num_elements, num_categories, true), shuffle_(false) {} + +ShardPkSample::ShardPkSample(const std::string &category_field, int64_t num_elements, int64_t num_categories, + uint32_t seed) + : ShardCategory(category_field, num_elements, num_categories, true), shuffle_(true) { + shuffle_op_ = std::make_shared(seed, kShuffleSample); // do shuffle and replacement +} + +MSRStatus ShardPkSample::suf_execute(ShardTask &tasks) { + if (shuffle_ == true) { + if (SUCCESS != (*shuffle_op_)(tasks)) { + return FAILED; + } + } + return SUCCESS; +} +} // namespace mindrecord +} // namespace mindspore diff --git a/mindspore/ccsrc/mindrecord/meta/shard_sample.cc b/mindspore/ccsrc/mindrecord/meta/shard_sample.cc index ef627b0c09f98aed0b6fdf36acba4806facc5c14..a9cfce0d01524d401ca087a9dfb18dccfd27f5f8 100644 --- a/mindspore/ccsrc/mindrecord/meta/shard_sample.cc +++ b/mindspore/ccsrc/mindrecord/meta/shard_sample.cc @@ -56,6 +56,24 @@ ShardSample::ShardSample(const std::vector &indices, uint32_t seed) shuffle_op_ = std::make_shared(seed); } +int64_t ShardSample::GetNumSamples(int64_t dataset_size, int64_t num_classes) { + if (sampler_type_ == kCustomTopNSampler) { + return no_of_samples_; + } + + if (sampler_type_ == kCustomTopPercentSampler) { + if (dataset_size % denominator_ == 0) { + return dataset_size / denominator_ * numerator_; + } else { + return dataset_size / denominator_ * numerator_ + 1; + } + } + if (sampler_type_ == kSubsetRandomSampler) { + return indices_.size(); + } + return -1; +} + const std::pair ShardSample::get_partitions() const { if (numerator_ == 1 && denominator_ > 1) { return std::pair(denominator_, partition_id_); diff --git a/mindspore/ccsrc/mindrecord/meta/shard_shuffle.cc b/mindspore/ccsrc/mindrecord/meta/shard_shuffle.cc index f8ad2c341dc84ae0c6558354b8130f4ba382141f..757dcb7b74d54537b0b01c315cc232b7e2fcf48b 100644 --- a/mindspore/ccsrc/mindrecord/meta/shard_shuffle.cc +++ b/mindspore/ccsrc/mindrecord/meta/shard_shuffle.cc @@ -20,25 +20,33 @@ namespace mindspore { namespace mindrecord { -ShardShuffle::ShardShuffle(uint32_t seed) : shuffle_seed_(seed) {} +ShardShuffle::ShardShuffle(uint32_t seed, ShuffleType shuffle_type) + : shuffle_seed_(seed), shuffle_type_(shuffle_type) {} MSRStatus ShardShuffle::execute(ShardTask &tasks) { if (tasks.categories < 1) { return FAILED; } - uint32_t individual_size = tasks.Size() / tasks.categories; - std::vector> new_permutations(tasks.categories, std::vector(individual_size)); - for (uint32_t i = 0; i < tasks.categories; i++) { - for (uint32_t j = 0; j < individual_size; j++) new_permutations[i][j] = static_cast(j); - std::shuffle(new_permutations[i].begin(), new_permutations[i].end(), std::default_random_engine(shuffle_seed_)); - } - shuffle_seed_++; - tasks.permutation_.clear(); - for (uint32_t j = 0; j < individual_size; j++) { + if (shuffle_type_ == kShuffleSample) { + if (tasks.permutation_.empty() == true) { + tasks.MakePerm(); + } + std::shuffle(tasks.permutation_.begin(), tasks.permutation_.end(), std::default_random_engine(shuffle_seed_)); + } else { // shuffle unit like: (a1, b1, c1),(a2, b2, c2),..., (an, bn, cn) + uint32_t individual_size = tasks.Size() / tasks.categories; + std::vector> new_permutations(tasks.categories, std::vector(individual_size)); for (uint32_t i = 0; i < tasks.categories; i++) { - tasks.permutation_.push_back(new_permutations[i][j] * static_cast(tasks.categories) + static_cast(i)); + for (uint32_t j = 0; j < individual_size; j++) new_permutations[i][j] = static_cast(j); + std::shuffle(new_permutations[i].begin(), new_permutations[i].end(), std::default_random_engine(shuffle_seed_)); + } + tasks.permutation_.clear(); + for (uint32_t j = 0; j < individual_size; j++) { + for (uint32_t i = 0; i < tasks.categories; i++) { + tasks.permutation_.push_back(new_permutations[i][j] * static_cast(tasks.categories) + static_cast(i)); + } } } + shuffle_seed_++; return SUCCESS; } } // namespace mindrecord diff --git a/mindspore/ccsrc/mindrecord/meta/shard_task.cc b/mindspore/ccsrc/mindrecord/meta/shard_task.cc index 3744d881a4d491e7ae73ab9127eaaed2b52deb71..be566d160127d07b358be52bb42258ae079b14e7 100644 --- a/mindspore/ccsrc/mindrecord/meta/shard_task.cc +++ b/mindspore/ccsrc/mindrecord/meta/shard_task.cc @@ -35,8 +35,6 @@ void ShardTask::InsertTask(int shard_id, int group_id, const std::vector, std::vector, json> task) { @@ -44,9 +42,6 @@ void ShardTask::InsertTask(std::tuple, std::vector(std::get<0>(task)) << ", label: " << std::get<2>(task).dump() << ", size of task_list_: " << task_list_.size() << "."; task_list_.push_back(std::move(task)); - MS_LOG(DEBUG) << "Out of insert task, shard_id: " << std::get<0>(std::get<0>(task)) - << ", group_id: " << std::get<1>(std::get<0>(task)) << ", label: " << std::get<2>(task).dump() - << ", size of task_list_: " << task_list_.size() << "."; } void ShardTask::PopBack() { task_list_.pop_back(); } @@ -69,18 +64,39 @@ std::tuple, std::vector, json> &ShardTask::get_ta return task_list_[id]; } -ShardTask ShardTask::Combine(std::vector &category_tasks) { +std::tuple, std::vector, json> &ShardTask::get_random_task() { + std::random_device rd; + std::mt19937 gen(rd()); + std::uniform_int_distribution<> dis(0, task_list_.size() - 1); + return task_list_[dis(gen)]; +} +ShardTask ShardTask::Combine(std::vector &category_tasks, bool replacement, int64_t num_elements) { ShardTask res; if (category_tasks.empty()) return res; auto total_categories = category_tasks.size(); res.categories = static_cast(total_categories); - auto minTasks = category_tasks[0].Size(); - for (uint32_t i = 1; i < total_categories; i++) { - minTasks = std::min(minTasks, category_tasks[i].Size()); - } - for (uint32_t task_no = 0; task_no < minTasks; task_no++) { + if (replacement == false) { + auto minTasks = category_tasks[0].Size(); + for (uint32_t i = 1; i < total_categories; i++) { + minTasks = std::min(minTasks, category_tasks[i].Size()); + } + for (uint32_t task_no = 0; task_no < minTasks; task_no++) { + for (uint32_t i = 0; i < total_categories; i++) { + res.InsertTask(std::move(category_tasks[i].get_task_by_id(static_cast(task_no)))); + } + } + } else { + auto maxTasks = category_tasks[0].Size(); + for (uint32_t i = 1; i < total_categories; i++) { + maxTasks = std::max(maxTasks, category_tasks[i].Size()); + } + if (num_elements != std::numeric_limits::max()) { + maxTasks = static_cast(num_elements); + } for (uint32_t i = 0; i < total_categories; i++) { - res.InsertTask(std::move(category_tasks[i].get_task_by_id(static_cast(task_no)))); + for (uint32_t j = 0; j < maxTasks; j++) { + res.InsertTask(category_tasks[i].get_random_task()); + } } } return res; diff --git a/mindspore/dataset/engine/datasets.py b/mindspore/dataset/engine/datasets.py index 5b3c0f1503842f2c81d08d489c15e825a4245d36..28697a6c434811c376bda6e0b71fd582f74f463b 100644 --- a/mindspore/dataset/engine/datasets.py +++ b/mindspore/dataset/engine/datasets.py @@ -1882,7 +1882,8 @@ class MindDataset(SourceDataset): block_reader (bool, optional): Whether read data by block mode (default=False). sampler (Sampler, optional): Object used to choose samples from the dataset (default=None, sampler is exclusive - with shuffle and block_reader). Support list: SubsetRandomSampler. + with shuffle and block_reader). Support list: SubsetRandomSampler, + PkSampler Raises: ValueError: If num_shards is specified but shard_id is None. @@ -1915,8 +1916,10 @@ class MindDataset(SourceDataset): if block_reader is True: logger.warning("WARN: global shuffle is not used.") - if sampler is not None and isinstance(sampler, samplers.SubsetRandomSampler) is False: - raise ValueError("the sampler is not supported yet.") + if sampler is not None: + if isinstance(sampler, samplers.SubsetRandomSampler) is False and \ + isinstance(sampler, samplers.PKSampler) is False: + raise ValueError("the sampler is not supported yet.") # sampler exclusive if block_reader is True and sampler is not None: @@ -1952,7 +1955,7 @@ class MindDataset(SourceDataset): Number, number of batches. """ - num_rows = MindRecordOp.get_num_rows(self.dataset_file) + num_rows = MindRecordOp.get_num_rows(self.dataset_file, self.sampler) if self.partitions is not None and self.partitions[0] > 0: if num_rows % self.partitions[0] == 0: num_rows = num_rows // self.partitions[0] diff --git a/mindspore/dataset/engine/samplers.py b/mindspore/dataset/engine/samplers.py index 421a03ab8de45e44646e2fc447f4cc348bfd1665..82759989cbed23f0bfbc531463e8885ee635f7eb 100644 --- a/mindspore/dataset/engine/samplers.py +++ b/mindspore/dataset/engine/samplers.py @@ -184,6 +184,8 @@ class PKSampler(BuiltinSampler): def create(self): return cde.PKSampler(self.num_val, self.shuffle) + def _create_for_minddataset(self): + return cde.MindrecordPkSampler(self.num_val, self.shuffle) class RandomSampler(BuiltinSampler): """ diff --git a/tests/ut/cpp/mindrecord/ut_shard_operator_test.cc b/tests/ut/cpp/mindrecord/ut_shard_operator_test.cc index 549e2140f41b1ce1da3cee33000f94b1a51eda6a..bfd49069b20cd9c710db90a4f618d5934173dbba 100644 --- a/tests/ut/cpp/mindrecord/ut_shard_operator_test.cc +++ b/tests/ut/cpp/mindrecord/ut_shard_operator_test.cc @@ -25,6 +25,7 @@ #include "gtest/gtest.h" #include "utils/log_adapter.h" #include "mindrecord/include/shard_category.h" +#include "mindrecord/include/shard_pk_sample.h" #include "mindrecord/include/shard_reader.h" #include "mindrecord/include/shard_sample.h" #include "mindrecord/include/shard_shuffle.h" @@ -146,6 +147,57 @@ TEST_F(TestShardOperator, TestShardSamplePartition) { ASSERT_TRUE(i <= 10); } +TEST_F(TestShardOperator, TestShardPkSamplerBasic) { + MS_LOG(INFO) << common::SafeCStr(FormatInfo("Test pk sampler")); + + std::string file_name = "./imagenet.shard01"; + auto column_list = std::vector{"file_name", "label"}; + + std::vector> ops; + ops.push_back(std::make_shared("label", 2)); + + ShardReader dataset; + dataset.Open(file_name, 4, column_list, ops); + dataset.Launch(); + + int i = 0; + while (true) { + auto x = dataset.GetNext(); + if (x.empty()) break; + std::cout << "index: " << i << ", filename: " << common::SafeCStr((std::get<1>(x[0]))["file_name"]) + << ", label: " << common::SafeCStr((std::get<1>(x[0]))["label"].dump()) << std::endl; + i++; + } + dataset.Finish(); + ASSERT_TRUE(i == 20); +} // namespace mindrecord + +TEST_F(TestShardOperator, TestShardPkSamplerNumClass) { + MS_LOG(INFO) << common::SafeCStr(FormatInfo("Test pk sampler")); + + std::string file_name = "./imagenet.shard01"; + auto column_list = std::vector{"file_name", "label"}; + + std::vector> ops; + ops.push_back(std::make_shared("label", 2, 3, 0)); + + ShardReader dataset; + dataset.Open(file_name, 4, column_list, ops); + dataset.Launch(); + + int i = 0; + while (true) { + auto x = dataset.GetNext(); + if (x.empty()) break; + + std::cout << "index: " << i << ", filename: " << common::SafeCStr((std::get<1>(x[0]))["file_name"]) + << ", label: " << common::SafeCStr((std::get<1>(x[0]))["label"].dump()) << std::endl; + i++; + } + dataset.Finish(); + ASSERT_TRUE(i == 6); +} // namespace mindrecord + TEST_F(TestShardOperator, TestShardCategory) { MS_LOG(INFO) << common::SafeCStr(FormatInfo("Test read imageNet")); diff --git a/tests/ut/data/mindrecord/testImageNetData/annotation_sampler.txt b/tests/ut/data/mindrecord/testImageNetData/annotation_sampler.txt new file mode 100644 index 0000000000000000000000000000000000000000..fbfbba025fb90f9e1ba09f6b556e1ff4e3076a65 --- /dev/null +++ b/tests/ut/data/mindrecord/testImageNetData/annotation_sampler.txt @@ -0,0 +1,10 @@ +image_00001.jpg,164 +image_00002.jpg,164 +image_00003.jpg,164 +image_00004.jpg,599 +image_00005.jpg,599 +image_00006.jpg,599 +image_00007.jpg,13 +image_00008.jpg,13 +image_00009.jpg,13 +image_00010.jpg,13 diff --git a/tests/ut/python/dataset/test_minddataset_sampler.py b/tests/ut/python/dataset/test_minddataset_sampler.py index 3cad3877efa22841250bd8f05eeb81276d01df08..584bb8804137f4fc813bac4b9667bd904a8e5d95 100644 --- a/tests/ut/python/dataset/test_minddataset_sampler.py +++ b/tests/ut/python/dataset/test_minddataset_sampler.py @@ -46,7 +46,7 @@ def add_and_remove_cv_file(): if os.path.exists("{}.db".format(x)): os.remove("{}.db".format(x)) writer = FileWriter(CV_FILE_NAME, FILES_NUM) - data = get_data(CV_DIR_NAME) + data = get_data(CV_DIR_NAME, True) cv_schema_json = {"id": {"type": "int32"}, "file_name": {"type": "string"}, "label": {"type": "int32"}, @@ -61,6 +61,59 @@ def add_and_remove_cv_file(): os.remove("{}.db".format(x)) +def test_cv_minddataset_pk_sample_basic(add_and_remove_cv_file): + """tutorial for cv minderdataset.""" + columns_list = ["data", "file_name", "label"] + num_readers = 4 + sampler = ds.PKSampler(2) + data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers, + sampler=sampler) + + assert data_set.get_dataset_size() == 6 + num_iter = 0 + for item in data_set.create_dict_iterator(): + logger.info("-------------- cv reader basic: {} ------------------------".format(num_iter)) + logger.info("-------------- item[file_name]: \ + {}------------------------".format("".join([chr(x) for x in item["file_name"]]))) + logger.info("-------------- item[label]: {} ----------------------------".format(item["label"])) + num_iter += 1 + + +def test_cv_minddataset_pk_sample_shuffle(add_and_remove_cv_file): + """tutorial for cv minderdataset.""" + columns_list = ["data", "file_name", "label"] + num_readers = 4 + sampler = ds.PKSampler(3, None, True) + data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers, + sampler=sampler) + + assert data_set.get_dataset_size() == 9 + num_iter = 0 + for item in data_set.create_dict_iterator(): + logger.info("-------------- cv reader basic: {} ------------------------".format(num_iter)) + logger.info("-------------- item[file_name]: \ + {}------------------------".format("".join([chr(x) for x in item["file_name"]]))) + logger.info("-------------- item[label]: {} ----------------------------".format(item["label"])) + num_iter += 1 + + +def test_cv_minddataset_pk_sample_out_of_range(add_and_remove_cv_file): + """tutorial for cv minderdataset.""" + columns_list = ["data", "file_name", "label"] + num_readers = 4 + sampler = ds.PKSampler(5, None, True) + data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers, + sampler=sampler) + assert data_set.get_dataset_size() == 15 + num_iter = 0 + for item in data_set.create_dict_iterator(): + logger.info("-------------- cv reader basic: {} ------------------------".format(num_iter)) + logger.info("-------------- item[file_name]: \ + {}------------------------".format("".join([chr(x) for x in item["file_name"]]))) + logger.info("-------------- item[label]: {} ----------------------------".format(item["label"])) + num_iter += 1 + + def test_cv_minddataset_subset_random_sample_basic(add_and_remove_cv_file): """tutorial for cv minderdataset.""" columns_list = ["data", "file_name", "label"] @@ -69,8 +122,7 @@ def test_cv_minddataset_subset_random_sample_basic(add_and_remove_cv_file): sampler = ds.SubsetRandomSampler(indices) data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers, sampler=sampler) - data = get_data(CV_DIR_NAME) - assert data_set.get_dataset_size() == 10 + assert data_set.get_dataset_size() == 5 num_iter = 0 for item in data_set.create_dict_iterator(): logger.info( @@ -93,8 +145,7 @@ def test_cv_minddataset_subset_random_sample_replica(add_and_remove_cv_file): sampler = ds.SubsetRandomSampler(indices) data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers, sampler=sampler) - data = get_data(CV_DIR_NAME) - assert data_set.get_dataset_size() == 10 + assert data_set.get_dataset_size() == 6 num_iter = 0 for item in data_set.create_dict_iterator(): logger.info( @@ -117,8 +168,7 @@ def test_cv_minddataset_subset_random_sample_empty(add_and_remove_cv_file): sampler = ds.SubsetRandomSampler(indices) data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers, sampler=sampler) - data = get_data(CV_DIR_NAME) - assert data_set.get_dataset_size() == 10 + assert data_set.get_dataset_size() == 0 num_iter = 0 for item in data_set.create_dict_iterator(): logger.info( @@ -133,7 +183,7 @@ def test_cv_minddataset_subset_random_sample_empty(add_and_remove_cv_file): assert num_iter == 0 -def test_cv_minddataset_subset_random_sample_out_range(add_and_remove_cv_file): +def test_cv_minddataset_subset_random_sample_out_of_range(add_and_remove_cv_file): """tutorial for cv minderdataset.""" columns_list = ["data", "file_name", "label"] num_readers = 4 @@ -141,8 +191,7 @@ def test_cv_minddataset_subset_random_sample_out_range(add_and_remove_cv_file): sampler = ds.SubsetRandomSampler(indices) data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers, sampler=sampler) - data = get_data(CV_DIR_NAME) - assert data_set.get_dataset_size() == 10 + assert data_set.get_dataset_size() == 5 num_iter = 0 for item in data_set.create_dict_iterator(): logger.info( @@ -165,8 +214,7 @@ def test_cv_minddataset_subset_random_sample_negative(add_and_remove_cv_file): sampler = ds.SubsetRandomSampler(indices) data_set = ds.MindDataset(CV_FILE_NAME + "0", columns_list, num_readers, sampler=sampler) - data = get_data(CV_DIR_NAME) - assert data_set.get_dataset_size() == 10 + assert data_set.get_dataset_size() == 5 num_iter = 0 for item in data_set.create_dict_iterator(): logger.info( @@ -181,7 +229,7 @@ def test_cv_minddataset_subset_random_sample_negative(add_and_remove_cv_file): assert num_iter == 5 -def get_data(dir_name): +def get_data(dir_name, sampler=False): """ usage: get data from imagenet dataset params: @@ -191,7 +239,10 @@ def get_data(dir_name): if not os.path.isdir(dir_name): raise IOError("Directory {} not exists".format(dir_name)) img_dir = os.path.join(dir_name, "images") - ann_file = os.path.join(dir_name, "annotation.txt") + if sampler: + ann_file = os.path.join(dir_name, "annotation_sampler.txt") + else: + ann_file = os.path.join(dir_name, "annotation.txt") with open(ann_file, "r") as file_reader: lines = file_reader.readlines() diff --git a/tests/ut/python/dataset/test_serdes_dataset.py b/tests/ut/python/dataset/test_serdes_dataset.py index 7fdb0f1dde51fc5a1d4dc761d429c21f8355702d..0a6f86974b02a006283435375d27d7f32070c0bb 100644 --- a/tests/ut/python/dataset/test_serdes_dataset.py +++ b/tests/ut/python/dataset/test_serdes_dataset.py @@ -243,7 +243,7 @@ def test_minddataset(add_and_remove_cv_file): assert ds1_json == ds2_json data = get_data(CV_DIR_NAME) - assert data_set.get_dataset_size() == 10 + assert data_set.get_dataset_size() == 5 num_iter = 0 for item in data_set.create_dict_iterator(): num_iter += 1