Seq2SQL: Generating Structured Queries from Natural Language using Reinforcement Learning (2017)
Paper link: https://arxiv.org/pdf/1709.00103.pdf
Authors: Victor Zhong, Caiming Xiong, Richard Socher
Here are my notes from this paper.
Seq2SQL uses the Seq2Seq model as a starting point.
- The output space of a general seq2seq model is unnecessarily large for SQL queries. In other words, a valid query only ever needs a small subset of the seq2seq output space.
- As a result, the output of Seq2SQL is limited to the union of the words in the question, the SQL keywords and the table's column headers.
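To make the restricted output space concrete, here is a minimal sketch of building that union. The helper name `candidate_tokens`, the `SQL_TOKENS` list and the whitespace tokenisation are my own assumptions for illustration, not the paper's code.

```python
# Hypothetical list of SQL tokens; the paper's actual SQL vocabulary may differ.
SQL_TOKENS = ["SELECT", "FROM", "WHERE", "AND", "COUNT", "MIN", "MAX", "SUM", "AVG", "=", ">", "<"]


def candidate_tokens(question, column_names, sql_tokens=SQL_TOKENS):
    """Return the union of column-name words, question words and SQL tokens.

    Tokens are kept in first-seen order and each one appears only once.
    Splitting on whitespace is an assumption made to keep the sketch short.
    """
    seen = {}
    for column in column_names:
        for word in column.lower().split():
            seen.setdefault(word, None)
    for word in question.lower().split():
        seen.setdefault(word, None)
    for token in sql_tokens:
        seen.setdefault(token, None)
    return list(seen)


tokens = candidate_tokens(
    "how many engine types did Val Musetti use",
    ["Entrant", "Constructor", "Chassis", "Engine", "No", "Driver"],
)
# "engine" is in both the question and the column names, but is listed once.
print(tokens)
```

Anything outside this list, for example a word that appears in neither the question nor the table, simply cannot be produced by the decoder.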
The paper created its own dataset, WikiSQL.
- It scraped Wikipedia tables, using some criteria to eliminate small tables.
- It used Mechanical Turk to create the natural language questions, like so:
  - Automatically generate badly formed “questions” from a template. Probably things like “how much max price for tomatoes” and “what person most often selected”. Every question is a query against an HTML table scraped from Wikipedia.
  - Get someone to reword them, turning them into natural language queries. They have to change at least 10 characters from the input for it to count.
  - Get two people to check the rewording. If both agree that the rewording is valid, keep it, else discard it.
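The acceptance rule for a reworded question can be sketched as below. This is only an illustration: the function names are hypothetical, and counting changed characters with `difflib` is my assumption, since I don't know exactly how the paper measures the 10-character change.

```python
import difflib


def chars_changed(template, paraphrase):
    """Rough count of changed characters, based on difflib's edit opcodes."""
    matcher = difflib.SequenceMatcher(None, template, paraphrase)
    changed = 0
    for tag, i1, i2, j1, j2 in matcher.get_opcodes():
        if tag != "equal":
            # Count the longer side of each replaced, inserted or deleted span.
            changed += max(i2 - i1, j2 - j1)
    return changed


def keep_paraphrase(template, paraphrase, reviewer_votes, min_changed=10):
    """Keep a rewording only if it differs enough and both reviewers approve."""
    differs_enough = chars_changed(template, paraphrase) >= min_changed
    both_approve = len(reviewer_votes) == 2 and all(reviewer_votes)
    return differs_enough and both_approve


template = "how much max price for tomatoes"
paraphrase = "What is the highest price paid for tomatoes?"
print(keep_paraphrase(template, paraphrase, [True, True]))
print(keep_paraphrase(template, paraphrase, [True, False]))
```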
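At first it wasn't clear to me where the “true” SQL for each natural language query comes from. As far as I can tell, it falls out of the generation step: a SQL query is randomly generated against the table first, the template “question” is produced from that query, and the query stays attached to the reworded question as its ground truth.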
The “ground truth” SQL is run to get the “correct” result. This result is then used to compute the reward for the generated SQL query.
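The paper's baseline is something called an augmented pointer network. A pointer network encodes the input with LSTMs and then generates the output token by token, each time pointing at (that is, copying) a token from the input sequence. Seq2SQL doesn't use a pointer network for the whole query, but it does keep a pointer-style decoder for the where clause.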
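A single decoding step of the "pointing" idea can be sketched like this. It is a toy version under my own assumptions: the vectors are made up, and dot-product scoring is a stand-in for whatever scoring function the paper actually uses.

```python
import math


def pointer_step(decoder_state, encoder_states, input_tokens):
    """Pick the next output token by pointing at one position of the input.

    Each input position gets a score (a dot product here, as an assumption),
    the scores are turned into a softmax distribution, and the highest
    probability position is copied to the output.
    """
    scores = [sum(d * e for d, e in zip(decoder_state, enc)) for enc in encoder_states]
    top = max(scores)
    exps = [math.exp(s - top) for s in scores]  # subtract max for numerical stability
    total = sum(exps)
    probs = [x / total for x in exps]
    best = max(range(len(probs)), key=probs.__getitem__)
    return input_tokens[best], probs


input_tokens = ["SELECT", "price", "tomatoes"]
encoder_states = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]  # made-up encodings
decoder_state = [0.2, 2.0]                             # made-up decoder state

token, probs = pointer_step(decoder_state, encoder_states, input_tokens)
print(token)  # → tomatoes
```

The point is that the output vocabulary at every step is just the input sequence, which is what keeps the model from producing tokens that could never be part of the query.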
Seq2SQL is a three-part model that exploits the structure inherent in SQL queries. The three parts are aggregation, select and where. Reinforcement learning is only used for the where clause.
- The aggregation part is a multi-layer perceptron. The select part uses an LSTM.
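The three-part structure maps naturally onto a small data structure: one slot per component, assembled into SQL at the end. The class below is a hypothetical illustration of that decomposition, not the paper's model; it does no escaping and assumes simple column names without spaces.

```python
from dataclasses import dataclass, field
from typing import List, Optional, Tuple


@dataclass
class SketchQuery:
    """One slot per Seq2SQL component (names are illustrative only)."""
    select_column: str                      # filled by the select part
    aggregation: Optional[str] = None       # filled by the aggregation part, e.g. "COUNT"
    conditions: List[Tuple[str, str, str]] = field(default_factory=list)  # where part

    def to_sql(self, table: str) -> str:
        """Assemble the three parts into a SQL string (no escaping, sketch only)."""
        target = self.select_column
        if self.aggregation:
            target = f"{self.aggregation}({target})"
        sql = f"SELECT {target} FROM {table}"
        if self.conditions:
            clauses = [f"{col} {op} '{value}'" for col, op, value in self.conditions]
            sql += " WHERE " + " AND ".join(clauses)
        return sql


query = SketchQuery("engine", "COUNT", [("driver", "=", "Val Musetti")])
print(query.to_sql("cars"))
# → SELECT COUNT(engine) FROM cars WHERE driver = 'Val Musetti'
```

Because the aggregation and select slots each pick one item from a small fixed set, they can be trained as plain classifiers; only the where slot needs to generate a variable-length sequence.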
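The where clause uses a “pointer decoder”. Training it is tricky because the conditions in a where clause can appear in any order without changing the result, and you don't want to penalise one ordering unnecessarily, which a token-by-token comparison against the ground truth query would do. Reinforcement learning is used to fix this: the generated query is executed and rewarded based on its result. The rewards are -2 if the query is not valid SQL, -1 if it is valid but returns the wrong result, and +1 if it returns the right result.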
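Here is a minimal sketch of that reward scheme using Python's built-in `sqlite3`. The function name, the toy `produce` table and the order-insensitive comparison of result rows are my own assumptions; only the -2 / -1 / +1 values come from the paper.

```python
import sqlite3


def execution_reward(conn, generated_sql, ground_truth_sql):
    """Return -2 for invalid SQL, -1 for a wrong result, +1 for the right result."""
    expected = conn.execute(ground_truth_sql).fetchall()
    try:
        actual = conn.execute(generated_sql).fetchall()
    except sqlite3.Error:
        return -2  # the generated query could not be executed
    # Compare rows ignoring their order (an assumption of this sketch).
    return 1 if sorted(actual) == sorted(expected) else -1


conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE produce (name TEXT, colour TEXT, price REAL)")
conn.executemany(
    "INSERT INTO produce VALUES (?, ?, ?)",
    [("tomato", "red", 2.5), ("pepper", "red", 3.0), ("kale", "green", 1.5)],
)

truth = "SELECT name FROM produce WHERE colour = 'red' AND price > 2"
reordered = "SELECT name FROM produce WHERE price > 2 AND colour = 'red'"
wrong = "SELECT name FROM produce WHERE colour = 'green'"
invalid = "SELECT name FROM produce WHERE"

print(execution_reward(conn, reordered, truth))  # → 1
print(execution_reward(conn, wrong, truth))      # → -1
print(execution_reward(conn, invalid, truth))    # → -2
```

The `reordered` query has its conditions in a different order from the ground truth, so a token-by-token comparison would mark it as wrong, yet it earns the full reward here because only the executed result is compared.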