markovify: subreddit and more fixes.
This commit is contained in:
parent
e186034b3c
commit
a4d043ee33
|
|
@ -21,6 +21,7 @@ import re
|
|||
import json
|
||||
import markovify
|
||||
import spacy
|
||||
from psaw import PushshiftAPI
|
||||
from ftfy import fix_text
|
||||
from nltk.tokenize import sent_tokenize
|
||||
import gc
|
||||
|
|
@ -34,6 +35,7 @@ except ImportError:
|
|||
_ = lambda x: x
|
||||
|
||||
nlp = spacy.load('en_core_web_sm')
|
||||
api = PushshiftAPI()
|
||||
|
||||
CONTRACTION_MAP = {
|
||||
"ain't": "is not",
|
||||
|
|
@ -188,13 +190,14 @@ class Markovify(callbacks.Plugin):
|
|||
json.dump(jsondata, outfile)
|
||||
|
||||
def add_text(self, channel, text):
|
||||
text = self.capsents(text)
|
||||
text = self.expand_contractions(text)
|
||||
text = fix_text(text)
|
||||
if self.registryValue('stripURL', channel):
|
||||
text = re.sub(r'(?i)\b((?:[a-z][\w-]+:(?:/{1,3}|[a-z0-9%])|www\d{0,3}[.]|[a-z0-9.\-]+[.][a-z]{2,4}/)(?:[^\s()<>]+|\(([^\s()<>]+|(\([^\s()<>]+\)))*\))+(?:\(([^\s()<>]+|(\([^\s()<>]+\)))*\)|[^\s`!()\[\]{};:\'".,<>?«»“”‘’]))', '', text)
|
||||
text = re.sub("(^')|('$)|\s'|'\s|[\"(\(\)\[\])]", "", text)
|
||||
text = self.expand_contractions(text)
|
||||
text = self.capsents(text)
|
||||
text = re.sub('<[^<]+?>', '', text)
|
||||
text = fix_text(text)
|
||||
text = re.sub("^'|'$|\s'|'\s|[\"()[\]*`:;<>]", "", text)
|
||||
text = re.sub("\s+", " ", text)
|
||||
try:
|
||||
self.model[channel] = markovify.combine(models=[self.model[channel], POSifiedText(text, retain_original=False)])
|
||||
except KeyError:
|
||||
|
|
@ -223,7 +226,7 @@ class Markovify(callbacks.Plugin):
|
|||
return
|
||||
if response and len(response) > 1 and not response.isspace():
|
||||
response = re.sub(' ([.!?,;:]) ', '\g<1> ', response)
|
||||
response = re.sub(' ([.!?,])$', '\g<1>', response)
|
||||
response = re.sub(" ([.!?,'%])$", "\g<1>", response)
|
||||
response = re.sub('([.?!,])(?=[^\s])', '\g<1> ', response)
|
||||
response = response.replace(' - ', '-')
|
||||
return response
|
||||
|
|
@ -251,22 +254,6 @@ class Markovify(callbacks.Plugin):
|
|||
expanded_text = re.sub("'", "", expanded_text)
|
||||
return expanded_text
|
||||
|
||||
def _subreddit(self, subreddit, latest_timestamp=None):
|
||||
"""
|
||||
Downloads the subreddit comments, 500 at a time.
|
||||
"""
|
||||
base_url = "https://api.pushshift.io/reddit/comment/search/"
|
||||
params = {"subreddit": subreddit, "sort": "desc",
|
||||
"sort_type": "created_utc", "size": 500, "user_removed": False, "mod_removed": False}
|
||||
if latest_timestamp != None:
|
||||
params["before"] = latest_timestamp
|
||||
with requests.get(base_url, params=params) as response:
|
||||
data = response.json()
|
||||
self.count += len(data["data"])
|
||||
self.latest_timestamp = data['data'][-1]["created_utc"]
|
||||
data = [item['body'] for item in data["data"]]
|
||||
return data
|
||||
|
||||
def doPrivmsg(self, irc, msg):
|
||||
(channel, message) = msg.args
|
||||
channel = channel.lower()
|
||||
|
|
@ -331,7 +318,7 @@ class Markovify(callbacks.Plugin):
|
|||
return None
|
||||
|
||||
def subreddit(self, irc, msg, args, channel, optlist, subreddits):
|
||||
"""[channel] <subreddit_1> [subreddit_2] [subreddit_3] [...etc.]
|
||||
"""[channel] [--num ####] <subreddit_1> [subreddit_2] [subreddit_3] [...etc.]
|
||||
Load subreddit comments into channel corpus.
|
||||
"""
|
||||
if not channel:
|
||||
|
|
@ -344,14 +331,15 @@ class Markovify(callbacks.Plugin):
|
|||
max_comments = 500
|
||||
for subreddit in subreddits.lower().strip().split(' '):
|
||||
self.latest_timestamp = None
|
||||
irc.reply("Attempting to retrieve last {0} comments from r/{1}".format(max_comments, subreddit))
|
||||
self.count = 0
|
||||
text = ""
|
||||
tries = 0
|
||||
data = []
|
||||
data.extend(self._subreddit(subreddit, self.latest_timestamp))
|
||||
if data:
|
||||
gen = api.search_comments(subreddit=subreddit, filter=['body'], limit=max_comments)
|
||||
if gen:
|
||||
data = list(gen)
|
||||
count = len(data)
|
||||
irc.reply("Retrieved {0} comments from r/{1}.".format(count, subreddit))
|
||||
for line in data:
|
||||
line = line.body
|
||||
if not line.strip() or line.isspace():
|
||||
continue
|
||||
if '[removed]' in line:
|
||||
|
|
@ -363,10 +351,14 @@ class Markovify(callbacks.Plugin):
|
|||
break
|
||||
if not ends_with_punctuation:
|
||||
line = line + "."
|
||||
text += " {}".format(line)
|
||||
if len(line.strip()) > 1:
|
||||
text += " {}".format(line)
|
||||
self.add_text(channel, text)
|
||||
else:
|
||||
irc.reply("Error fetching data from r/{}".format(subreddit))
|
||||
return
|
||||
self.save_corpus(channel)
|
||||
irc.reply("Added {0} comments from r/{1}.".format(self.count, subreddit))
|
||||
irc.reply("Added {0} comments from r/{1} to corpus for channel {2}.".format(count, subreddit, channel))
|
||||
del data, text
|
||||
gc.collect()
|
||||
subreddit = wrap(subreddit, [additional('channel'), getopts({'num':'int'}), 'text'])
|
||||
|
|
|
|||
Loading…
Reference in New Issue