markovify: subreddit and more fixes.

This commit is contained in:
oddluck 2019-12-22 17:16:45 +00:00
parent e186034b3c
commit a4d043ee33
1 changed file with 21 additions and 29 deletions

View File

@ -21,6 +21,7 @@ import re
import json
import markovify
import spacy
from psaw import PushshiftAPI
from ftfy import fix_text
from nltk.tokenize import sent_tokenize
import gc
@ -34,6 +35,7 @@ except ImportError:
_ = lambda x: x
nlp = spacy.load('en_core_web_sm')
api = PushshiftAPI()
CONTRACTION_MAP = {
"ain't": "is not",
@ -188,13 +190,14 @@ class Markovify(callbacks.Plugin):
json.dump(jsondata, outfile)
def add_text(self, channel, text):
text = self.capsents(text)
text = self.expand_contractions(text)
text = fix_text(text)
if self.registryValue('stripURL', channel):
text = re.sub(r'(?i)\b((?:[a-z][\w-]+:(?:/{1,3}|[a-z0-9%])|www\d{0,3}[.]|[a-z0-9.\-]+[.][a-z]{2,4}/)(?:[^\s()<>]+|\(([^\s()<>]+|(\([^\s()<>]+\)))*\))+(?:\(([^\s()<>]+|(\([^\s()<>]+\)))*\)|[^\s`!()\[\]{};:\'".,<>?«»“”‘’]))', '', text)
text = re.sub("(^')|('$)|\s'|'\s|[\"(\(\)\[\])]", "", text)
text = self.expand_contractions(text)
text = self.capsents(text)
text = re.sub('<[^<]+?>', '', text)
text = fix_text(text)
text = re.sub("^'|'$|\s'|'\s|[\"()[\]*`:;<>]", "", text)
text = re.sub("\s+", " ", text)
try:
self.model[channel] = markovify.combine(models=[self.model[channel], POSifiedText(text, retain_original=False)])
except KeyError:
@ -223,7 +226,7 @@ class Markovify(callbacks.Plugin):
return
if response and len(response) > 1 and not response.isspace():
response = re.sub(' ([.!?,;:]) ', '\g<1> ', response)
response = re.sub(' ([.!?,])$', '\g<1>', response)
response = re.sub(" ([.!?,'%])$", "\g<1>", response)
response = re.sub('([.?!,])(?=[^\s])', '\g<1> ', response)
response = response.replace(' - ', '-')
return response
@ -251,22 +254,6 @@ class Markovify(callbacks.Plugin):
expanded_text = re.sub("'", "", expanded_text)
return expanded_text
def _subreddit(self, subreddit, latest_timestamp=None):
    """
    Download one page (at most 500) of comments from a subreddit.

    Queries the Pushshift comment-search endpoint, asking for results
    sorted by ``created_utc`` in descending order.

    Args:
        subreddit: Name of the subreddit to search.
        latest_timestamp: Optional value sent as the ``before`` query
            parameter; omitted from the request when ``None``.

    Returns:
        A list with the ``body`` of every comment in the response, or an
        empty list when the response contains no comments.

    Side effects:
        Adds the number of comments received to ``self.count`` and, when
        at least one comment was received, stores the ``created_utc`` of
        the last one in ``self.latest_timestamp``.
    """
    base_url = "https://api.pushshift.io/reddit/comment/search/"
    params = {
        "subreddit": subreddit,
        "sort": "desc",
        "sort_type": "created_utc",
        "size": 500,
        "user_removed": False,
        "mod_removed": False,
    }
    if latest_timestamp is not None:
        params["before"] = latest_timestamp
    # Without a timeout a stalled connection would block this call forever.
    with requests.get(base_url, params=params, timeout=30) as response:
        payload = response.json()
    comments = payload.get("data") or []
    if not comments:
        # Nothing came back: leave count/latest_timestamp untouched instead
        # of failing with IndexError on comments[-1].
        return []
    self.count += len(comments)
    self.latest_timestamp = comments[-1]["created_utc"]
    return [item["body"] for item in comments]
def doPrivmsg(self, irc, msg):
(channel, message) = msg.args
channel = channel.lower()
@ -331,7 +318,7 @@ class Markovify(callbacks.Plugin):
return None
def subreddit(self, irc, msg, args, channel, optlist, subreddits):
"""[channel] <subreddit_1> [subreddit_2] [subreddit_3] [...etc.]
"""[channel] [--num ####] <subreddit_1> [subreddit_2] [subreddit_3] [...etc.]
Load subreddit comments into channel corpus.
"""
if not channel:
@ -344,14 +331,15 @@ class Markovify(callbacks.Plugin):
max_comments = 500
for subreddit in subreddits.lower().strip().split(' '):
self.latest_timestamp = None
irc.reply("Attempting to retrieve last {0} comments from r/{1}".format(max_comments, subreddit))
self.count = 0
text = ""
tries = 0
data = []
data.extend(self._subreddit(subreddit, self.latest_timestamp))
if data:
gen = api.search_comments(subreddit=subreddit, filter=['body'], limit=max_comments)
if gen:
data = list(gen)
count = len(data)
irc.reply("Retrieved {0} comments from r/{1}.".format(count, subreddit))
for line in data:
line = line.body
if not line.strip() or line.isspace():
continue
if '[removed]' in line:
@ -363,10 +351,14 @@ class Markovify(callbacks.Plugin):
break
if not ends_with_punctuation:
line = line + "."
text += " {}".format(line)
if len(line.strip()) > 1:
text += " {}".format(line)
self.add_text(channel, text)
else:
irc.reply("Error fetching data from r/{}".format(subreddit))
return
self.save_corpus(channel)
irc.reply("Added {0} comments from r/{1}.".format(self.count, subreddit))
irc.reply("Added {0} comments from r/{1} to corpus for channel {2}.".format(count, subreddit, channel))
del data, text
gc.collect()
subreddit = wrap(subreddit, [additional('channel'), getopts({'num':'int'}), 'text'])