"Reasoning problem based on common sense knowledge" source code analysis - overall structure problem

Posted by kyme on Sun, 21 Nov 2021 20:07:43 +0100

2021SC@SDUSC

Over the past few weeks of source code analysis, we have looked at how DrFact preprocesses the corpus and at what the earlier steps of the DrFact model algorithm do. This week, however, I will not analyze any specific piece of source code. While reviewing my previous analyses, I noticed that the DrFact model borrows from other models to a considerable extent. This is most visible in its core source code, which calls functions defined in other models. The analysis therefore can no longer focus only on the drfact package; it also has to examine the other source code in the OpenCSR project.

In a sense, I misjudged where the core of the OpenCSR source code lies, which means I need to put more effort into understanding the project as a whole. Until I have a macro-level understanding of the overall structure of the source code and of how the models connect to each other, I will set aside the specific algorithms in the DrFact model and the details of the other data processing code. Instead, I will analyze the modules and functions that are shared between models. This post records those shared modules and functions.

1, DrFact model and DrKit model

To analyze the relationships between the models, and especially how they affect our main research object, the DrFact model, we start with the DrFact model itself. Its source code makes it clear that the DrKit model is directly related to it. Many DrFact source files import Python modules from the DrKit model and call functions defined there, such as:

In the convert_dpr_index.py source file, search_utils.py from the DrKit model is imported and its write_to_checkpoint() function is called.

from language.labs.drkit import search_utils

  with tf.device("/cpu:0"):
    search_utils.write_to_checkpoint("fact_db_emb", fact_emb, tf.float32, output_index_path)

In the fact2fact_index.py source file, search_utils.py from the DrKit model is also imported, and its write_ragged_to_checkpoint() function is called.

from language.labs.drkit import search_utils

  search_utils.write_ragged_to_checkpoint(
      "fact2fact", sp_fact2fact,
      os.path.join(FLAGS.fact2fact_index_dir, "fact2fact.npz"))

In the index_core.py source file, search_utils.py from the DrKit model is also imported, and its write_to_checkpoint() and write_ragged_to_checkpoint() functions are called multiple times, along with load_database(). In addition, this source file imports bert_utils_v2.py and hotpotqa/index.py, which it uses to create a BERTPredictor and to call the get_sub_paras() function, respectively.

from language.labs.drkit import bert_utils_v2
from language.labs.drkit import search_utils
from language.labs.drkit.hotpotqa import index as index_util

# The following are examples of calls made through the imported search_utils module
  search_utils.write_to_checkpoint(
      "coref", np.array([m[0] for m in mentions], dtype=np.int32), tf.int32,
      os.path.join(FLAGS.index_result_path, "coref.npz"))

  search_utils.write_ragged_to_checkpoint(
      "ent2ment", sp_entity2mention,
      os.path.join(FLAGS.index_result_path, "ent2ment.npz"))

  search_utils.write_ragged_to_checkpoint(
      "ent2fact", sp_entity2fact,
      os.path.join(FLAGS.index_result_path,
                   "ent2fact_%d.npz" % FLAGS.max_facts_per_entity))

  search_utils.write_ragged_to_checkpoint(
      "fact2ent", sp_fact2entity,
      os.path.join(FLAGS.index_result_path, "fact_coref.npz"))

  search_utils.write_to_checkpoint(
      "entity_ids", entity_ids, tf.int32,
      os.path.join(FLAGS.index_result_path, "entity_ids"))

  search_utils.write_to_checkpoint(
      "entity_mask", entity_mask, tf.float32,
      os.path.join(FLAGS.index_result_path, "entity_mask"))

      with tf.device("/cpu:0"):
        search_utils.write_to_checkpoint(
            "db_emb_%d" % ns, mention_emb, tf.float32,
            os.path.join(FLAGS.index_result_path,
                         "%s_mention_feats_%d" % (embed_prefix, ns)))

      tf_db = search_utils.load_database(
          db_emb_str + "_%d" % i, var_to_shape_map[db_emb_str + "_%d" % i],
          ckpt_path)

    search_utils.write_to_checkpoint(
        db_emb_str, np_db, tf.float32,
        os.path.join(FLAGS.index_result_path, embed_feats_str))

      with tf.device("/cpu:0"):
        search_utils.write_to_checkpoint(
            "fact_db_emb_%d" % ns, fact_emb, tf.float32,
            os.path.join(FLAGS.index_result_path,
                         "%s_fact_feats_%d" % (embed_prefix, ns)))

# The following is an example of a call made through the imported bert_utils_v2 module
    bert_predictor = bert_utils_v2.BERTPredictor(tokenizer, bert_ckpt)

# The following is an example of a call made through the imported hotpotqa.index module (aliased as index_util)
    sub_para_objs = index_util.get_sub_paras(orig_para, tokenizer,
                                             FLAGS.max_seq_length,
                                             FLAGS.doc_stride, total_sub_paras)

In the input_fns.py source file, input_fns.py from the DrKit model is imported (under the alias input_utils) and its get_tokens_and_mask() function is called.

from language.labs.drkit import input_fns as input_utils

    (qry_input_ids, qry_input_mask,
     qry_tokens) = input_utils.get_tokens_and_mask(example.question_text,
                                                   tokenizer, max_query_length)

In the model_fns.py source file, search_utils.py from the DrKit model is also imported, and its create_mips_searcher() function is called multiple times. In addition, this source file imports DrKit's model_fns.py (under the alias model_utils) and calls the following functions from it: entity_emb(), sparse_ragged_mul(), ensure_values_in_mat(), convert_search_to_vector(), sp_sp_matmul(), rescore_sparse(), aggregate_sparse_indices(), shared_qry_encoder_v2(), layer_qry_encoder(), remove_from_sparse(), batch_multiply() and compute_loss_from_sptensors().

from language.labs.drkit import model_fns as model_utils
from language.labs.drkit import search_utils

# The following is an example of a function called through the imported search_utils module
  with tf.device("/cpu:0"):
    tf_fact_db, fact_mips_search_fn = search_utils.create_mips_searcher(
        fact_mips_config.ckpt_var_name,
        # [fact_mips_config.num_facts, fact_mips_config.emb_size],
        fact_mips_config.ckpt_path,
        fact_mips_config.num_neighbors,
        local_var_name="scam_init_barrier_fact")

# The following are some of the functions called through the imported model_fns module (aliased as model_utils)
  batch_entity_emb = model_utils.entity_emb(entity_ind, entity_word_ids,
                                            entity_word_masks, word_emb_table,
                                            word_weights)

  sp_mention_vec = model_utils.sparse_ragged_mul(
      batch_entities,
      ent2ment_ind,
      ent2ment_val,
      batch_size,
      mips_config.num_mentions,
      qa_config.sparse_reduce_fn,  # max or sum
      threshold=qa_config.entity_score_threshold,
      fix_values_to_one=qa_config.fix_sparse_to_one)

    if is_training and qa_config.ensure_answer_dense:
      ret_mention_ids = model_utils.ensure_values_in_mat(
          ret_mention_ids, ensure_index, tf.int32)

  dense_mention_vec = model_utils.convert_search_to_vector(
      ret_mention_scs, ret_mention_ids, tf.cast(batch_size, tf.int32),
      mips_config.num_neighbors, mips_config.num_mentions)

    if qa_config.sparse_strategy == "dense_first":
      ret_mention_vec = model_utils.sp_sp_matmul(dense_mention_vec,
                                                 sp_mention_vec)

      with tf.device("/cpu:0"):
        ret_mention_vec = model_utils.rescore_sparse(sp_mention_vec, tf_db,
                                                     scam_qrys)

      uniq_fact_ids, uniq_fact_scs = model_utils.aggregate_sparse_indices(
          sp_fact_vec.indices, sp_fact_vec.values, sp_fact_vec.dense_shape,
          "max")
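As a rough illustration of the call above, the sketch below merges repeated sparse indices and keeps one score per unique index. It is plain Python written under my own assumptions, not the TensorFlow code in DrKit: the function name aggregate_sparse_indices_sketch is hypothetical, the dense_shape argument is omitted, and returning the unique indices in sorted order is a choice made for this example.

```python
def aggregate_sparse_indices_sketch(indices, values, agg_fn="max"):
    """Hypothetical sketch: merge duplicate indices with `agg_fn` ("max" or "sum")."""
    combine = {"max": max, "sum": lambda a, b: a + b}[agg_fn]
    merged = {}
    for index, value in zip(indices, values):
        key = tuple(index)
        # First occurrence stores the value; later ones are combined with it.
        merged[key] = combine(merged[key], value) if key in merged else value
    unique_indices = sorted(merged)
    return unique_indices, [merged[key] for key in unique_indices]


# (row, fact_id) pairs; fact 7 in row 0 was retrieved twice with different scores.
indices = [(0, 7), (0, 3), (0, 7), (1, 3)]
values = [0.25, 0.5, 0.75, 0.5]
uniq_ids, uniq_scores = aggregate_sparse_indices_sketch(indices, values, "max")
print(uniq_ids)     # [(0, 3), (0, 7), (1, 3)]
print(uniq_scores)  # [0.5, 0.75, 0.5]
```

With "max", a fact that is reached several times keeps its best score; with "sum", the scores would accumulate instead.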

  qry_seq_emb, word_emb_table, qry_hidden_size = model_utils.shared_qry_encoder_v2(
      qry_input_ids, qry_input_mask, is_training, use_one_hot_embeddings,
      bert_config, qa_config)
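The import relationships traced by hand in this post can also be collected mechanically. The following is a minimal sketch using Python's standard ast module: it lists which modules a source file imports from language.labs.drkit and which functions it calls through those imports. The helper names drkit_imports and drkit_calls and the example source string are hypothetical; only direct calls of the form alias.function(...) are detected.

```python
import ast

DRKIT = "language.labs.drkit"


def drkit_imports(source, package=DRKIT):
    """Map each local alias to the fully qualified module imported from `package`."""
    found = {}
    for node in ast.walk(ast.parse(source)):
        if isinstance(node, ast.ImportFrom) and node.module and (
                node.module == package or node.module.startswith(package + ".")):
            for alias in node.names:
                found[alias.asname or alias.name] = node.module + "." + alias.name
    return found


def drkit_calls(source, package=DRKIT):
    """Return the sorted qualified names of functions called through those aliases."""
    aliases = drkit_imports(source, package)
    calls = set()
    for node in ast.walk(ast.parse(source)):
        if (isinstance(node, ast.Call)
                and isinstance(node.func, ast.Attribute)
                and isinstance(node.func.value, ast.Name)
                and node.func.value.id in aliases):
            calls.add(aliases[node.func.value.id] + "." + node.func.attr)
    return sorted(calls)


# Hypothetical excerpt; the source is only parsed, never executed.
example = '''
import os
from language.labs.drkit import model_fns as model_utils
from language.labs.drkit import search_utils

tf_db, search_fn = search_utils.create_mips_searcher(name, path, k)
vec = model_utils.sp_sp_matmul(dense_vec, sparse_vec)
path = os.path.join("a", "b")
'''
print(drkit_imports(example))
print(drkit_calls(example))
```

Running such a scan over every file in the drfact package would give the same kind of list as the one compiled above, and could be repeated for the other models in OpenCSR.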

Topics: Python NLP