aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
author Eugen Rochko <eugen@zeonfederated.com> 2021-03-12 07:00:05 +0100
committer Eugen Rochko <eugen@zeonfederated.com> 2021-03-15 08:03:18 +0100
commit 19a64c968749f163ec767d741d90826dd32983c4 (patch)
tree 499dfb2655ae4cd29d47c68e2dcc055d439f9cc6
parent b5057c47176fe3170eec148462f97a1e3964b93a (diff)
downloadmastodon-19a64c968749f163ec767d741d90826dd32983c4.tar
mastodon-19a64c968749f163ec767d741d90826dd32983c4.tar.gz
mastodon-19a64c968749f163ec767d741d90826dd32983c4.tar.bz2
mastodon-19a64c968749f163ec767d741d90826dd32983c4.zip
Refactor raw SQL queries to use Arel [refactor-raw-sql-queries]
-rw-r--r-- Gemfile 1
-rw-r--r-- Gemfile.lock 3
-rw-r--r-- app/lib/account_search_query_builder.rb 188
-rw-r--r-- app/models/account.rb 80
-rw-r--r-- app/services/account_search_service.rb 20
-rw-r--r-- spec/lib/account_search_query_builder_spec.rb 101
-rw-r--r-- spec/models/account_spec.rb 119
7 files changed, 301 insertions, 211 deletions
diff --git a/Gemfile b/Gemfile
index a4a2cc91c..4e8e6eb9d 100644
--- a/Gemfile
+++ b/Gemfile
@@ -26,6 +26,7 @@ gem 'streamio-ffmpeg', '~> 3.0'
gem 'blurhash', '~> 0.1'
gem 'active_model_serializers', '~> 0.10'
+gem 'activerecord-cte', '~> 0.1'
gem 'addressable', '~> 2.7'
gem 'bootsnap', '~> 1.6.0', require: false
gem 'browser'
diff --git a/Gemfile.lock b/Gemfile.lock
index b59cfb1f3..8ddf66773 100644
--- a/Gemfile.lock
+++ b/Gemfile.lock
@@ -54,6 +54,8 @@ GEM
activemodel (= 5.2.4.5)
activesupport (= 5.2.4.5)
arel (>= 9.0)
+ activerecord-cte (0.1.1)
+ activerecord
activestorage (5.2.4.5)
actionpack (= 5.2.4.5)
activerecord (= 5.2.4.5)
@@ -695,6 +697,7 @@ PLATFORMS
DEPENDENCIES
active_model_serializers (~> 0.10)
active_record_query_trace (~> 1.8)
+ activerecord-cte (~> 0.1)
addressable (~> 2.7)
annotate (~> 3.1)
aws-sdk-s3 (~> 1.89)
diff --git a/app/lib/account_search_query_builder.rb b/app/lib/account_search_query_builder.rb
new file mode 100644
index 000000000..c757da914
--- /dev/null
+++ b/app/lib/account_search_query_builder.rb
@@ -0,0 +1,188 @@
+# frozen_string_literal: true
+
+# Builds the PostgreSQL full-text search over accounts as an ActiveRecord
+# relation assembled from Arel nodes, optionally personalized for one account.
+class AccountSearchQueryBuilder
+  DISALLOWED_TSQUERY_CHARACTERS = /['?\\:‘’]/.freeze
+
+  LANGUAGE = Arel::Nodes.build_quoted('simple').freeze
+  EMPTY_STRING = Arel::Nodes.build_quoted('').freeze
+  WEIGHT_A = Arel::Nodes.build_quoted('A').freeze
+  WEIGHT_B = Arel::Nodes.build_quoted('B').freeze
+  WEIGHT_C = Arel::Nodes.build_quoted('C').freeze
+
+  FIELDS = {
+    display_name: { weight: WEIGHT_A }.freeze,
+    username: { weight: WEIGHT_B }.freeze,
+    domain: { weight: WEIGHT_C, nullable: true }.freeze,
+  }.freeze
+
+  # Third argument of ts_rank_cd; per the PostgreSQL docs, 32 means rank / (rank + 1).
+  RANK_NORMALIZATION = 32
+
+  DEFAULT_OPTIONS = {
+    limit: 10,
+    only_following: false,
+  }.freeze
+
+  # @param [String] terms
+  # @param [Hash] options
+  # @option options [Account] :account
+  # @option options [Boolean] :only_following
+  # @option options [Integer] :limit
+  # @option options [Integer] :offset
+  def initialize(terms, options = {})
+    @terms = terms
+    @options = DEFAULT_OPTIONS.merge(options)
+  end
+
+  # @return [ActiveRecord::Relation]
+  def build
+    search_scope.tap do |scope|
+      scope.merge!(personalization_scope) if with_account?
+
+      if with_account? && only_following?
+        scope.merge!(only_following_scope)
+        scope.with!(first_degree_definition) # `merge!` does not handle `with`
+      end
+    end
+  end
+
+  # @return [Array<Account>]
+  def results
+    build.to_a
+  end
+
+  private
+
+  # Orders by the bare `rank` select alias: `order(rank: :desc)` would be
+  # qualified as "accounts"."rank", and `rank` is only an alias from `projections`.
+  def search_scope
+    Account.select(projections)
+           .where(match_condition)
+           .searchable
+           .includes(:account_stat)
+           .order(Arel.sql('rank').desc)
+           .limit(limit)
+           .offset(offset)
+  end
+
+  # Outer-joins follows rows linking each result to `account` in either direction, grouped by account id.
+  def personalization_scope
+    join_condition = accounts_table.join(follows_table, Arel::Nodes::OuterJoin)
+                                   .on(accounts_table.grouping(accounts_table[:id].eq(follows_table[:account_id]).and(follows_table[:target_account_id].eq(account.id))).or(accounts_table.grouping(accounts_table[:id].eq(follows_table[:target_account_id]).and(follows_table[:account_id].eq(account.id)))))
+                                   .join_sources
+
+    Account.joins(join_condition)
+           .group(accounts_table[:id])
+  end
+
+  def only_following_scope
+    Account.where(accounts_table[:id].in(first_degree_table.project('*')))
+  end
+
+  # CTE `first_degree`: the ids `account` follows, UNION ALL the id of `account` itself.
+  def first_degree_definition
+    target_account_ids_query = follows_table.project(follows_table[:target_account_id]).where(follows_table[:account_id].eq(account.id))
+    account_id_query = Arel::SelectManager.new.project(account.id)
+
+    Arel::Nodes::As.new(
+      first_degree_table,
+      target_account_ids_query.union(:all, account_id_query)
+    )
+  end
+
+  def projections
+    rank = with_account? ? weighted_tsrank_template : tsrank_template
+    [all_columns, rank.as('rank')]
+  end
+
+  def all_columns
+    accounts_table[Arel.star]
+  end
+
+  def match_condition
+    Arel::Nodes::InfixOperation.new('@@', tsvector_template, tsquery_template)
+  end
+
+  def tsrank_template
+    @tsrank_template ||= Arel::Nodes::NamedFunction.new('ts_rank_cd', [tsvector_template, tsquery_template, RANK_NORMALIZATION])
+  end
+
+  def weighted_tsrank_template
+    @weighted_tsrank_template ||= Arel::Nodes::Multiplication.new(weight, tsrank_template)
+  end
+
+  # Count of joined follows rows plus one, so a result with no joined row keeps its plain rank.
+  def weight
+    Arel::Nodes::Addition.new(follows_table[:id].count, 1)
+  end
+
+  # Concatenation (SQL `||`) of the weighted `to_tsvector` of every column listed in FIELDS.
+  def tsvector_template
+    return @tsvector_template if defined?(@tsvector_template)
+
+    vectors = FIELDS.keys.map do |column|
+      options = FIELDS[column]
+
+      vector = accounts_table[column]
+      vector = Arel::Nodes::NamedFunction.new('coalesce', [vector, EMPTY_STRING]) if options[:nullable]
+      vector = Arel::Nodes::NamedFunction.new('to_tsvector', [LANGUAGE, vector])
+
+      Arel::Nodes::NamedFunction.new('setweight', [vector, options[:weight]])
+    end
+
+    @tsvector_template = Arel::Nodes::Grouping.new(vectors.reduce { |memo, vector| Arel::Nodes::Concat.new(memo, vector) })
+  end
+
+  def sanitized_terms
+    @sanitized_terms ||= @terms.gsub(DISALLOWED_TSQUERY_CHARACTERS, ' ')
+  end
+
+  # `to_tsquery` over the sanitized terms, wrapped in single quotes and suffixed
+  # with `:*`; the pieces are joined in SQL with `||` from separately quoted values.
+  def tsquery_template
+    return @tsquery_template if defined?(@tsquery_template)
+
+    terms = [
+      Arel::Nodes.build_quoted("' "),
+      Arel::Nodes.build_quoted(sanitized_terms),
+      Arel::Nodes.build_quoted(" '"),
+      Arel::Nodes.build_quoted(':*'),
+    ]
+
+    @tsquery_template = Arel::Nodes::NamedFunction.new('to_tsquery', [LANGUAGE, terms.reduce { |memo, term| Arel::Nodes::Concat.new(memo, term) }])
+  end
+
+  def account
+    @options[:account]
+  end
+
+  def with_account?
+    account.present?
+  end
+
+  def limit
+    @options[:limit]
+  end
+
+  def offset
+    @options[:offset]
+  end
+
+  def only_following?
+    @options[:only_following]
+  end
+
+  def accounts_table
+    Account.arel_table
+  end
+
+  def follows_table
+    Follow.arel_table
+  end
+
+  def first_degree_table
+    Arel::Table.new(:first_degree)
+  end
+end
diff --git a/app/models/account.rb b/app/models/account.rb
index d85fd1f6e..a31c2428d 100644
--- a/app/models/account.rb
+++ b/app/models/account.rb
@@ -432,75 +432,6 @@ class Account < ApplicationRecord
DeliveryFailureTracker.without_unavailable(urls)
end
- def search_for(terms, limit = 10, offset = 0)
- textsearch, query = generate_query_for_search(terms)
-
- sql = <<-SQL.squish
- SELECT
- accounts.*,
- ts_rank_cd(#{textsearch}, #{query}, 32) AS rank
- FROM accounts
- WHERE #{query} @@ #{textsearch}
- AND accounts.suspended_at IS NULL
- AND accounts.moved_to_account_id IS NULL
- ORDER BY rank DESC
- LIMIT ? OFFSET ?
- SQL
-
- records = find_by_sql([sql, limit, offset])
- ActiveRecord::Associations::Preloader.new.preload(records, :account_stat)
- records
- end
-
- def advanced_search_for(terms, account, limit = 10, following = false, offset = 0)
- textsearch, query = generate_query_for_search(terms)
-
- if following
- sql = <<-SQL.squish
- WITH first_degree AS (
- SELECT target_account_id
- FROM follows
- WHERE account_id = ?
- UNION ALL
- SELECT ?
- )
- SELECT
- accounts.*,
- (count(f.id) + 1) * ts_rank_cd(#{textsearch}, #{query}, 32) AS rank
- FROM accounts
- LEFT OUTER JOIN follows AS f ON (accounts.id = f.account_id AND f.target_account_id = ?)
- WHERE accounts.id IN (SELECT * FROM first_degree)
- AND #{query} @@ #{textsearch}
- AND accounts.suspended_at IS NULL
- AND accounts.moved_to_account_id IS NULL
- GROUP BY accounts.id
- ORDER BY rank DESC
- LIMIT ? OFFSET ?
- SQL
-
- records = find_by_sql([sql, account.id, account.id, account.id, limit, offset])
- else
- sql = <<-SQL.squish
- SELECT
- accounts.*,
- (count(f.id) + 1) * ts_rank_cd(#{textsearch}, #{query}, 32) AS rank
- FROM accounts
- LEFT OUTER JOIN follows AS f ON (accounts.id = f.account_id AND f.target_account_id = ?) OR (accounts.id = f.target_account_id AND f.account_id = ?)
- WHERE #{query} @@ #{textsearch}
- AND accounts.suspended_at IS NULL
- AND accounts.moved_to_account_id IS NULL
- GROUP BY accounts.id
- ORDER BY rank DESC
- LIMIT ? OFFSET ?
- SQL
-
- records = find_by_sql([sql, account.id, account.id, limit, offset])
- end
-
- ActiveRecord::Associations::Preloader.new.preload(records, :account_stat)
- records
- end
-
def from_text(text)
return [] if text.blank?
@@ -512,19 +443,10 @@ class Account < ApplicationRecord
TagManager.instance.normalize_domain(domain)
end
end
+
EntityCache.instance.mention(username, domain)
end
end
-
- private
-
- def generate_query_for_search(terms)
- terms = Arel.sql(connection.quote(terms.gsub(/['?\\:]/, ' ')))
- textsearch = "(setweight(to_tsvector('simple', accounts.display_name), 'A') || setweight(to_tsvector('simple', accounts.username), 'B') || setweight(to_tsvector('simple', coalesce(accounts.domain, '')), 'C'))"
- query = "to_tsquery('simple', ''' ' || #{terms} || ' ''' || ':*')"
-
- [textsearch, query]
- end
end
def emojis
diff --git a/app/services/account_search_service.rb b/app/services/account_search_service.rb
index 6fe4b6593..3a4a937a3 100644
--- a/app/services/account_search_service.rb
+++ b/app/services/account_search_service.rb
@@ -53,19 +53,13 @@ class AccountSearchService < BaseService
end
def from_database
- if account
- advanced_search_results
- else
- simple_search_results
- end
- end
-
- def advanced_search_results
- Account.advanced_search_for(terms_for_query, account, limit_for_non_exact_results, options[:following], offset)
- end
-
- def simple_search_results
- Account.search_for(terms_for_query, limit_for_non_exact_results, offset)
+ AccountSearchQueryBuilder.new(
+ terms_for_query,
+ account: account,
+ only_following: options[:following],
+ limit: limit_for_non_exact_results,
+ offset: offset
+ ).results
end
def from_elasticsearch
diff --git a/spec/lib/account_search_query_builder_spec.rb b/spec/lib/account_search_query_builder_spec.rb
new file mode 100644
index 000000000..dc59e2784
--- /dev/null
+++ b/spec/lib/account_search_query_builder_spec.rb
@@ -0,0 +1,101 @@
+# frozen_string_literal: true
+
+require 'rails_helper'
+
+describe AccountSearchQueryBuilder do
+  # Builds the query for +terms+ and returns the loaded accounts in query order.
+  def search(terms, options = {})
+    described_class.new(terms, options).build.to_a
+  end
+
+  # Extra account that no example below expects to see in its results.
+  before do
+    Fabricate(
+      :account,
+      display_name: 'Missing',
+      username: 'missing',
+      domain: 'missing.com'
+    )
+  end
+
+  context 'without account' do
+    it 'accepts ?, \, : and space as delimiter' do
+      expected = Fabricate(
+        :account,
+        display_name: 'A & l & i & c & e',
+        username: 'username',
+        domain: 'example.com'
+      )
+
+      expect(search('A?l\i:c e')).to eq [expected]
+    end
+
+    it 'finds accounts with matching display_name' do
+      expected = Fabricate(
+        :account,
+        display_name: 'Display Name',
+        username: 'username',
+        domain: 'example.com'
+      )
+
+      expect(search('display')).to eq [expected]
+    end
+
+    it 'finds accounts with matching username' do
+      expected = Fabricate(
+        :account,
+        display_name: 'Display Name',
+        username: 'username',
+        domain: 'example.com'
+      )
+
+      expect(search('username')).to eq [expected]
+    end
+
+    it 'finds accounts with matching domain' do
+      expected = Fabricate(
+        :account,
+        display_name: 'Display Name',
+        username: 'username',
+        domain: 'example.com'
+      )
+
+      expect(search('example')).to eq [expected]
+    end
+
+    it 'limits by 10 by default' do
+      11.times { Fabricate(:account, display_name: 'Display Name') }
+      expect(search('display').size).to eq 10
+    end
+
+    it 'accepts arbitrary limits' do
+      2.times { Fabricate(:account, display_name: 'Display Name') }
+      expect(search('display', limit: 1).size).to eq 1
+    end
+
+    it 'ranks multiple matches higher' do
+      expected = [
+        { username: 'username', display_name: 'username' },
+        { display_name: 'Display Name', username: 'username', domain: 'example.com' },
+      ].map { |attributes| Fabricate(:account, attributes) }
+
+      expect(search('username')).to eq expected
+    end
+  end
+
+  context 'with account' do
+    # The searching account; it is passed to the builder as the :account option.
+    let(:account) { Fabricate(:account) }
+
+    it 'ranks followed accounts higher' do
+      plain_match = Fabricate(:account, username: 'Matching')
+      followed_match = Fabricate(:account, username: 'Matcher')
+      account.follow!(followed_match)
+
+      results = search('match', account: account)
+
+      expect(results).to eq [followed_match, plain_match]
+      expect(results.first.rank).to be > results.last.rank
+    end
+  end
+end
diff --git a/spec/models/account_spec.rb b/spec/models/account_spec.rb
index 03d6f5fb0..a22640f1f 100644
--- a/spec/models/account_spec.rb
+++ b/spec/models/account_spec.rb
@@ -309,125 +309,6 @@ RSpec.describe Account, type: :model do
end
end
- describe '.search_for' do
- before do
- _missing = Fabricate(
- :account,
- display_name: "Missing",
- username: "missing",
- domain: "missing.com"
- )
- end
-
- it 'accepts ?, \, : and space as delimiter' do
- match = Fabricate(
- :account,
- display_name: 'A & l & i & c & e',
- username: 'username',
- domain: 'example.com'
- )
-
- results = Account.search_for('A?l\i:c e')
- expect(results).to eq [match]
- end
-
- it 'finds accounts with matching display_name' do
- match = Fabricate(
- :account,
- display_name: "Display Name",
- username: "username",
- domain: "example.com"
- )
-
- results = Account.search_for("display")
- expect(results).to eq [match]
- end
-
- it 'finds accounts with matching username' do
- match = Fabricate(
- :account,
- display_name: "Display Name",
- username: "username",
- domain: "example.com"
- )
-
- results = Account.search_for("username")
- expect(results).to eq [match]
- end
-
- it 'finds accounts with matching domain' do
- match = Fabricate(
- :account,
- display_name: "Display Name",
- username: "username",
- domain: "example.com"
- )
-
- results = Account.search_for("example")
- expect(results).to eq [match]
- end
-
- it 'limits by 10 by default' do
- 11.times.each { Fabricate(:account, display_name: "Display Name") }
- results = Account.search_for("display")
- expect(results.size).to eq 10
- end
-
- it 'accepts arbitrary limits' do
- 2.times.each { Fabricate(:account, display_name: "Display Name") }
- results = Account.search_for("display", 1)
- expect(results.size).to eq 1
- end
-
- it 'ranks multiple matches higher' do
- matches = [
- { username: "username", display_name: "username" },
- { display_name: "Display Name", username: "username", domain: "example.com" },
- ].map(&method(:Fabricate).curry(2).call(:account))
-
- results = Account.search_for("username")
- expect(results).to eq matches
- end
- end
-
- describe '.advanced_search_for' do
- it 'accepts ?, \, : and space as delimiter' do
- account = Fabricate(:account)
- match = Fabricate(
- :account,
- display_name: 'A & l & i & c & e',
- username: 'username',
- domain: 'example.com'
- )
-
- results = Account.advanced_search_for('A?l\i:c e', account)
- expect(results).to eq [match]
- end
-
- it 'limits by 10 by default' do
- 11.times { Fabricate(:account, display_name: "Display Name") }
- results = Account.search_for("display")
- expect(results.size).to eq 10
- end
-
- it 'accepts arbitrary limits' do
- 2.times { Fabricate(:account, display_name: "Display Name") }
- results = Account.search_for("display", 1)
- expect(results.size).to eq 1
- end
-
- it 'ranks followed accounts higher' do
- account = Fabricate(:account)
- match = Fabricate(:account, username: "Matching")
- followed_match = Fabricate(:account, username: "Matcher")
- Fabricate(:follow, account: account, target_account: followed_match)
-
- results = Account.advanced_search_for("match", account)
- expect(results).to eq [followed_match, match]
- expect(results.first.rank).to be > results.last.rank
- end
- end
-
describe '#statuses_count' do
subject { Fabricate(:account) }